From 1476c64f0bb340d2bc610527669f6ef74d2e4eae Mon Sep 17 00:00:00 2001 From: Quentin Delfosse Date: Sat, 2 May 2026 16:16:45 +0200 Subject: [PATCH 01/28] More tennis mods --- src/jaxatari/core.py | 3 +- .../mods/gravitar/gravitar_mod_plugins.py | 28 ++ src/jaxatari/games/mods/gravitar_mods.py | 6 + .../mods/mspacman/mspacman_mod_plugins.py | 14 +- src/jaxatari/games/mods/mspacman_mods.py | 2 + .../games/mods/phoenix/phoenix_mod_plugins.py | 48 +++- src/jaxatari/games/mods/phoenix_mods.py | 14 +- .../games/mods/tennis/tennis_mod_plugins.py | 254 +++++++++++++++++- src/jaxatari/games/mods/tennis_mods.py | 15 +- 9 files changed, 375 insertions(+), 9 deletions(-) diff --git a/src/jaxatari/core.py b/src/jaxatari/core.py index 6bf05908d..70a533848 100644 --- a/src/jaxatari/core.py +++ b/src/jaxatari/core.py @@ -100,7 +100,8 @@ def _warn_deprecated_obs_to_flat_array(env: JaxEnvironment) -> None: "enduro": "jaxatari.games.mods.enduro_mods.EnduroEnvMod", "qbert": "jaxatari.games.mods.qbert_mods.QbertEnvMod", "mspacman": "jaxatari.games.mods.mspacman_mods.MsPacmanEnvMod", - "beamrider": "jaxatari.games.mods.beamrider_mods.BeamRiderEnvMod" + "beamrider": "jaxatari.games.mods.beamrider_mods.BeamRiderEnvMod", + "venture": "jaxatari.games.mods.venture_mods.VentureEnvMod" } diff --git a/src/jaxatari/games/mods/gravitar/gravitar_mod_plugins.py b/src/jaxatari/games/mods/gravitar/gravitar_mod_plugins.py index 629c0a35e..b428e2f87 100644 --- a/src/jaxatari/games/mods/gravitar/gravitar_mod_plugins.py +++ b/src/jaxatari/games/mods/gravitar/gravitar_mod_plugins.py @@ -81,3 +81,31 @@ class HighSpeedMod(JaxAtariInternalModPlugin): "THRUST_POWER": 0.075, "MAX_SPEED": 6.0, } + + +class InfiniteFuelMod(JaxAtariInternalModPlugin): + """Disable fuel consumption.""" + + constants_overrides = { + "FUEL_CONSUME_THRUST": 0.0, + "FUEL_CONSUME_SHIELD_TRACTOR": 0.0, + } + + +class SlowEnemiesMod(JaxAtariInternalModPlugin): + """Decrease the movement speed of saucers and bullets.""" + + constants_overrides = { + "SAUCER_SPEED_MAP": 0.09, + "SAUCER_SPEED_ARENA": 0.18, + "SAUCER_BULLET_SPEED": 1.0, + "ENEMY_BULLET_SPEED": 0.65, + } + + +class LongRangeTractorMod(JaxAtariInternalModPlugin): + """Increase the range of the tractor beam.""" + + constants_overrides = { + "TRACTOR_BEAM_RANGE": 50.0, + } diff --git a/src/jaxatari/games/mods/gravitar_mods.py b/src/jaxatari/games/mods/gravitar_mods.py index 997df18e7..5e3962ae2 100644 --- a/src/jaxatari/games/mods/gravitar_mods.py +++ b/src/jaxatari/games/mods/gravitar_mods.py @@ -8,6 +8,9 @@ ValuableReactorMod, AntiGravityMod, HighSpeedMod, + InfiniteFuelMod, + SlowEnemiesMod, + LongRangeTractorMod, ) @@ -23,6 +26,9 @@ class GravitarEnvMod(JaxAtariModController): "valuable_reactor": ValuableReactorMod, "anti_gravity": AntiGravityMod, "high_speed": HighSpeedMod, + "infinite_fuel": InfiniteFuelMod, + "slow_enemies": SlowEnemiesMod, + "long_range_tractor": LongRangeTractorMod, } def __init__( diff --git a/src/jaxatari/games/mods/mspacman/mspacman_mod_plugins.py b/src/jaxatari/games/mods/mspacman/mspacman_mod_plugins.py index a0ec9ba51..0d4120634 100644 --- a/src/jaxatari/games/mods/mspacman/mspacman_mod_plugins.py +++ b/src/jaxatari/games/mods/mspacman/mspacman_mod_plugins.py @@ -1,9 +1,21 @@ import jax import jax.numpy as jnp from functools import partial -from jaxatari.modification import JaxAtariPostStepModPlugin +from jaxatari.modification import JaxAtariPostStepModPlugin, JaxAtariInternalModPlugin from jaxatari.games.jax_mspacman import JaxPacman, GhostMode, reset_game +class FruitGhostBonusMod(JaxAtariInternalModPlugin): + """ + Mod that deactivates points for pellets and power pellets, + but multiplies rewards for eating ghosts and fruits by 4. + """ + constants_overrides = { + "PELLET_POINTS": 0, + "POWER_PELLET_POINTS": 0, + "FRUIT_REWARDS": jnp.array([400, 800, 2000, 2800, 4000, 8000, 20000]), + "EAT_GHOSTS_BASE_POINTS": 800, + } + class CagedGhostsMod(JaxAtariPostStepModPlugin): def _jail_position(self, dtype): return jnp.array(self._env.consts.JAIL_POSITION, dtype=dtype) diff --git a/src/jaxatari/games/mods/mspacman_mods.py b/src/jaxatari/games/mods/mspacman_mods.py index ac7ea3028..337038809 100644 --- a/src/jaxatari/games/mods/mspacman_mods.py +++ b/src/jaxatari/games/mods/mspacman_mods.py @@ -3,6 +3,7 @@ from jaxatari.games.mods.mspacman.mspacman_mod_plugins import ( CagedGhostsMod, ConstantFruitsMod, + FruitGhostBonusMod, SetMaze1Mod, SetMaze2Mod, SetMaze3Mod, @@ -20,6 +21,7 @@ class MsPacmanEnvMod(JaxAtariModController): REGISTRY = { "caged_ghosts": CagedGhostsMod, "constant_fruits": ConstantFruitsMod, + "fruit_ghost_bonus": FruitGhostBonusMod, "set_maze_1": SetMaze1Mod, "set_maze_2": SetMaze2Mod, "set_maze_3": SetMaze3Mod, diff --git a/src/jaxatari/games/mods/phoenix/phoenix_mod_plugins.py b/src/jaxatari/games/mods/phoenix/phoenix_mod_plugins.py index 11fcdfe05..8b3cfe34f 100644 --- a/src/jaxatari/games/mods/phoenix/phoenix_mod_plugins.py +++ b/src/jaxatari/games/mods/phoenix/phoenix_mod_plugins.py @@ -1,4 +1,6 @@ -from jaxatari.modification import JaxAtariInternalModPlugin +from jaxatari.modification import JaxAtariInternalModPlugin, JaxAtariPostStepModPlugin +import jax.numpy as jnp +from jaxatari.games.jax_phoenix import PhoenixState class BossLateMissilesMod(JaxAtariInternalModPlugin): @@ -11,3 +13,47 @@ class BossLateMissilesMod(JaxAtariInternalModPlugin): "BOSS_PROJECTILE_RENDER_DELAY_PX": 8, } + +class InfiniteLivesMod(JaxAtariInternalModPlugin): + """ + Set player lives to 99. + """ + constants_overrides = { + "PLAYER_LIVES": 99, + } + + +class FastPlayerMod(JaxAtariInternalModPlugin): + """ + Increases player movement speed. + """ + constants_overrides = { + "PLAYER_STEP_SIZE": 2, + } + + +class InvinciblePlayerMod(JaxAtariPostStepModPlugin): + """ + Player is always invincible. + """ + def run(self, prev_state: PhoenixState, new_state: PhoenixState) -> PhoenixState: + return new_state.replace(invincibility=jnp.array(True)) + + +class FastEnemyBulletsMod(JaxAtariInternalModPlugin): + """ + Increases speed of enemy projectiles. + """ + constants_overrides = { + "ENEMY_PROJECTILE_SPEED": 4, + } + + +class NoAbilityCooldownMod(JaxAtariInternalModPlugin): + """ + Removes cooldown for the special ability (shield). + """ + constants_overrides = { + "ABILITY_COOLDOWN": 0, + } + diff --git a/src/jaxatari/games/mods/phoenix_mods.py b/src/jaxatari/games/mods/phoenix_mods.py index ab57acef7..de588c7cd 100644 --- a/src/jaxatari/games/mods/phoenix_mods.py +++ b/src/jaxatari/games/mods/phoenix_mods.py @@ -1,5 +1,12 @@ from jaxatari.modification import JaxAtariModController -from jaxatari.games.mods.phoenix.phoenix_mod_plugins import BossLateMissilesMod +from jaxatari.games.mods.phoenix.phoenix_mod_plugins import ( + BossLateMissilesMod, + InfiniteLivesMod, + FastPlayerMod, + InvinciblePlayerMod, + FastEnemyBulletsMod, + NoAbilityCooldownMod, +) class PhoenixEnvMod(JaxAtariModController): @@ -9,6 +16,11 @@ class PhoenixEnvMod(JaxAtariModController): REGISTRY = { "boss_late_missiles": BossLateMissilesMod, + "infinite_lives": InfiniteLivesMod, + "fast_player": FastPlayerMod, + "invincible_player": InvinciblePlayerMod, + "fast_enemy_bullets": FastEnemyBulletsMod, + "no_ability_cooldown": NoAbilityCooldownMod, } def __init__( diff --git a/src/jaxatari/games/mods/tennis/tennis_mod_plugins.py b/src/jaxatari/games/mods/tennis/tennis_mod_plugins.py index 24f4b0609..6bea8faf7 100644 --- a/src/jaxatari/games/mods/tennis/tennis_mod_plugins.py +++ b/src/jaxatari/games/mods/tennis/tennis_mod_plugins.py @@ -1,12 +1,10 @@ import functools from typing import Any, Dict, Tuple, Union - - import chex import jax import jax.numpy as jnp from jax import lax -from jaxatari.games.jax_tennis import TennisState +from jaxatari.games.jax_tennis import TennisState, EnemyState from jaxatari.modification import JaxAtariPostStepModPlugin, JaxAtariInternalModPlugin @@ -70,4 +68,252 @@ def make_random(self, prev_state: TennisState, state: TennisState) -> TennisStat ) def run(self, prev_state: TennisState, new_state: TennisState) -> TennisState: - return self.make_random(prev_state, new_state) \ No newline at end of file + return self.make_random(prev_state, new_state) + + +class FastPlayerMod(JaxAtariPostStepModPlugin): + """ + Increases player walk speed. + """ + def after_reset(self, obs, state: TennisState) -> Tuple[Any, TennisState]: + return obs, state.replace(player_state=state.player_state.replace(player_walk_speed=jnp.array(2.0))) + + def run(self, prev_state: TennisState, new_state: TennisState) -> TennisState: + # Also ensure it stays fast if something else resets it + return new_state.replace(player_state=new_state.player_state.replace(player_walk_speed=jnp.array(2.0))) + + +class SuperGravityMod(JaxAtariInternalModPlugin): + """ + Increases ball gravity. + """ + constants_overrides = { + "BALL_GRAVITY_PER_FRAME": 2.2, + } + + +class LazyEnemyMod(JaxAtariInternalModPlugin): + """ + Enemy moves slower. + """ + @functools.partial(jax.jit, static_argnums=(0,)) + def _enemy_step(self, state: TennisState) -> EnemyState: + # Re-implementation of enemy step with slower movement + enemy_x_hit_point = state.enemy_state.enemy_x + self._env.consts.PLAYER_WIDTH / 2 + player_x_hit_point = state.player_state.player_x + self._env.consts.PLAYER_WIDTH / 2 + ball_tracking_tolerance = 1 + x_tracking_tolerance = 2 + + def move_x_to_middle(): + middle_step_x = jnp.where(jnp.less_equal(state.enemy_state.enemy_x, self._env.consts.GAME_MIDDLE_HORIZONTAL), + state.enemy_state.enemy_x + 0.5, + state.enemy_state.enemy_x - 0.5) + return jnp.where(jnp.abs(state.enemy_state.enemy_x - self._env.consts.GAME_MIDDLE_HORIZONTAL) > 0.5, middle_step_x, + state.enemy_state.enemy_x) + + def track_ball_x(): + enemy_aiming_x_offset = jnp.where( + player_x_hit_point < self._env.consts.FRAME_WIDTH / 2, + 5, + -15 + ) + diff = state.ball_state.ball_x - (enemy_x_hit_point + enemy_aiming_x_offset) + + # move right if ball is sufficiently to the right + new_enemy_x = jnp.where( + diff > x_tracking_tolerance, + state.enemy_state.enemy_x + 0.5, + state.enemy_state.enemy_x + ) + + # move left if ball is sufficiently to the left + new_enemy_x = jnp.where( + diff < -x_tracking_tolerance, + state.enemy_state.enemy_x - 0.5, + new_enemy_x + ) + return new_enemy_x + + new_enemy_x = jax.lax.cond(state.ball_state.last_hit == 1, move_x_to_middle, track_ball_x) + + cur_walking_direction = jnp.where( + new_enemy_x - state.enemy_state.enemy_x < 0, + -1, + state.enemy_state.prev_walking_direction + ) + cur_walking_direction = jnp.where( + new_enemy_x - state.enemy_state.enemy_x > 0, + 1, + cur_walking_direction + ) + + should_perform_direction_change = jnp.logical_or( + jnp.abs((enemy_x_hit_point) - state.ball_state.ball_x) >= ball_tracking_tolerance, + state.ball_state.last_hit == 1 + ) + new_enemy_x = jnp.where(should_perform_direction_change, new_enemy_x, state.enemy_state.enemy_x) + + def enemy_y_step(): + # Enemy moves slower in Y as well + y_speed = 0.5 + state_after_y = jax.lax.cond( + state.ball_state.last_hit == 1, + lambda _: EnemyState(state.enemy_state.enemy_x, + jnp.where(jnp.logical_and(state.enemy_state.enemy_y != self._env.consts.PLAYER_Y_LOWER_BOUND_TOP, + state.enemy_state.enemy_y != self._env.consts.PLAYER_Y_UPPER_BOUND_BOTTOM), + state.enemy_state.enemy_y - state.player_state.player_field * y_speed, + state.enemy_state.enemy_y + ), state.enemy_state.prev_walking_direction, + state.enemy_state.enemy_direction, jnp.array(1)), + lambda _: EnemyState(state.enemy_state.enemy_x, + jnp.where(state.player_state.player_field == 1, + jnp.clip(state.enemy_state.enemy_y + + state.player_state.player_field * state.enemy_state.y_movement_direction * y_speed, + self._env.consts.PLAYER_Y_UPPER_BOUND_BOTTOM, + self._env.consts.PLAYER_Y_LOWER_BOUND_BOTTOM), + jnp.clip(state.enemy_state.enemy_y + + state.player_state.player_field * state.enemy_state.y_movement_direction * y_speed, + self._env.consts.PLAYER_Y_UPPER_BOUND_TOP, + self._env.consts.PLAYER_Y_LOWER_BOUND_TOP)), + state.enemy_state.prev_walking_direction, + state.enemy_state.enemy_direction, + jnp.where(jnp.logical_or(state.enemy_state.enemy_y == self._env.consts.PLAYER_Y_UPPER_BOUND_TOP, + state.enemy_state.enemy_y == self._env.consts.PLAYER_Y_LOWER_BOUND_BOTTOM), + jnp.array(-1), + state.enemy_state.y_movement_direction)), operand=None) + + return jax.lax.cond( + state.game_state.is_serving, + lambda _: state.enemy_state, + lambda _: state_after_y, + operand=None + ) + + enemy_state_after_y_step = enemy_y_step() + new_enemy_direction = jnp.where(state.enemy_state.enemy_x > state.ball_state.ball_x, -1, state.enemy_state.enemy_direction) + new_enemy_direction = jnp.where(state.enemy_state.enemy_x < state.ball_state.ball_x, 1, new_enemy_direction) + + return EnemyState( + new_enemy_x, + enemy_state_after_y_step.enemy_y, + cur_walking_direction, + new_enemy_direction, + enemy_state_after_y_step.y_movement_direction + ) + + +class HighBounceMod(JaxAtariInternalModPlugin): + """ + Increases ball bounce velocity. + """ + constants_overrides = { + "BALL_SERVING_BOUNCE_VELOCITY_BASE": 30.0, + } + +class FastEnemyMod(JaxAtariInternalModPlugin): + """ + Enemy moves faster. + """ + @functools.partial(jax.jit, static_argnums=(0,)) + def _enemy_step(self, state: TennisState) -> EnemyState: + # Re-implementation of enemy step with faster movement + enemy_x_hit_point = state.enemy_state.enemy_x + self._env.consts.PLAYER_WIDTH / 2 + player_x_hit_point = state.player_state.player_x + self._env.consts.PLAYER_WIDTH / 2 + ball_tracking_tolerance = 1 + x_tracking_tolerance = 2 + + def move_x_to_middle(): + middle_step_x = jnp.where(jnp.less_equal(state.enemy_state.enemy_x, self._env.consts.GAME_MIDDLE_HORIZONTAL), + state.enemy_state.enemy_x + 2, + state.enemy_state.enemy_x - 2) + return jnp.where(jnp.abs(state.enemy_state.enemy_x - self._env.consts.GAME_MIDDLE_HORIZONTAL) > 2, middle_step_x, + state.enemy_state.enemy_x) + + def track_ball_x(): + enemy_aiming_x_offset = jnp.where( + player_x_hit_point < self._env.consts.FRAME_WIDTH / 2, + 5, + -15 + ) + diff = state.ball_state.ball_x - (enemy_x_hit_point + enemy_aiming_x_offset) + + # move right if ball is sufficiently to the right + new_enemy_x = jnp.where( + diff > x_tracking_tolerance, + state.enemy_state.enemy_x + 2, + state.enemy_state.enemy_x + ) + + # move left if ball is sufficiently to the left + new_enemy_x = jnp.where( + diff < -x_tracking_tolerance, + state.enemy_state.enemy_x - 2, + new_enemy_x + ) + return new_enemy_x + + new_enemy_x = jax.lax.cond(state.ball_state.last_hit == 1, move_x_to_middle, track_ball_x) + + cur_walking_direction = jnp.where( + new_enemy_x - state.enemy_state.enemy_x < 0, + -1, + state.enemy_state.prev_walking_direction + ) + cur_walking_direction = jnp.where( + new_enemy_x - state.enemy_state.enemy_x > 0, + 1, + cur_walking_direction + ) + + should_perform_direction_change = jnp.logical_or( + jnp.abs((enemy_x_hit_point) - state.ball_state.ball_x) >= ball_tracking_tolerance, + state.ball_state.last_hit == 1 + ) + new_enemy_x = jnp.where(should_perform_direction_change, new_enemy_x, state.enemy_state.enemy_x) + + def enemy_y_step(): + y_speed = 2.0 + state_after_y = jax.lax.cond( + state.ball_state.last_hit == 1, + lambda _: EnemyState(state.enemy_state.enemy_x, + jnp.where(jnp.logical_and(state.enemy_state.enemy_y != self._env.consts.PLAYER_Y_LOWER_BOUND_TOP, + state.enemy_state.enemy_y != self._env.consts.PLAYER_Y_UPPER_BOUND_BOTTOM), + state.enemy_state.enemy_y - state.player_state.player_field * y_speed, + state.enemy_state.enemy_y + ), state.enemy_state.prev_walking_direction, + state.enemy_state.enemy_direction, jnp.array(1)), + lambda _: EnemyState(state.enemy_state.enemy_x, + jnp.where(state.player_state.player_field == 1, + jnp.clip(state.enemy_state.enemy_y + + state.player_state.player_field * state.enemy_state.y_movement_direction * y_speed, + self._env.consts.PLAYER_Y_UPPER_BOUND_BOTTOM, + self._env.consts.PLAYER_Y_LOWER_BOUND_BOTTOM), + jnp.clip(state.enemy_state.enemy_y + + state.player_state.player_field * state.enemy_state.y_movement_direction * y_speed, + self._env.consts.PLAYER_Y_UPPER_BOUND_TOP, + self._env.consts.PLAYER_Y_LOWER_BOUND_TOP)), + state.enemy_state.prev_walking_direction, + state.enemy_state.enemy_direction, + jnp.where(jnp.logical_or(state.enemy_state.enemy_y == self._env.consts.PLAYER_Y_UPPER_BOUND_TOP, + state.enemy_state.enemy_y == self._env.consts.PLAYER_Y_LOWER_BOUND_BOTTOM), + jnp.array(-1), + state.enemy_state.y_movement_direction)), operand=None) + + return jax.lax.cond( + state.game_state.is_serving, + lambda _: state.enemy_state, + lambda _: state_after_y, + operand=None + ) + + enemy_state_after_y_step = enemy_y_step() + new_enemy_direction = jnp.where(state.enemy_state.enemy_x > state.ball_state.ball_x, -1, state.enemy_state.enemy_direction) + new_enemy_direction = jnp.where(state.enemy_state.enemy_x < state.ball_state.ball_x, 1, new_enemy_direction) + + return EnemyState( + new_enemy_x, + enemy_state_after_y_step.enemy_y, + cur_walking_direction, + new_enemy_direction, + enemy_state_after_y_step.y_movement_direction + ) diff --git a/src/jaxatari/games/mods/tennis_mods.py b/src/jaxatari/games/mods/tennis_mods.py index 568a10210..ab031d491 100644 --- a/src/jaxatari/games/mods/tennis_mods.py +++ b/src/jaxatari/games/mods/tennis_mods.py @@ -1,6 +1,14 @@ import os from jaxatari.modification import JaxAtariModController -from jaxatari.games.mods.tennis.tennis_mod_plugins import RandomWalkSpeedWrapper, RandomBallSpeedWrapper +from jaxatari.games.mods.tennis.tennis_mod_plugins import ( + RandomBallSpeedWrapper, + RandomWalkSpeedWrapper, + FastPlayerMod, + SuperGravityMod, + LazyEnemyMod, + HighBounceMod, + FastEnemyMod, +) class TennisEnvMod(JaxAtariModController): """ @@ -11,6 +19,11 @@ class TennisEnvMod(JaxAtariModController): REGISTRY = { "random_ball_speed": RandomBallSpeedWrapper, "random_walk_speed": RandomWalkSpeedWrapper, + "fast_player": FastPlayerMod, + "super_gravity": SuperGravityMod, + "lazy_enemy": LazyEnemyMod, + "high_bounce": HighBounceMod, + "fast_enemy": FastEnemyMod, } _mod_sprite_dir = os.path.join(os.path.dirname(__file__), "tennis", "sprites") From ae0e1383d7ecb6f3ccca0eb823032f5dadbd0a38 Mon Sep 17 00:00:00 2001 From: Quentin Delfosse Date: Sat, 2 May 2026 16:28:56 +0200 Subject: [PATCH 02/28] Venture mods --- .../games/mods/venture/venture_mod_plugins.py | 58 +++++++++++++++++++ src/jaxatari/games/mods/venture_mods.py | 37 ++++++++++++ 2 files changed, 95 insertions(+) create mode 100644 src/jaxatari/games/mods/venture/venture_mod_plugins.py create mode 100644 src/jaxatari/games/mods/venture_mods.py diff --git a/src/jaxatari/games/mods/venture/venture_mod_plugins.py b/src/jaxatari/games/mods/venture/venture_mod_plugins.py new file mode 100644 index 000000000..ef0f225c5 --- /dev/null +++ b/src/jaxatari/games/mods/venture/venture_mod_plugins.py @@ -0,0 +1,58 @@ +from jaxatari.modification import JaxAtariInternalModPlugin +import jax.numpy as jnp + + +class FastWinkyMod(JaxAtariInternalModPlugin): + """Increase Winky's movement speed.""" + + constants_overrides = { + "PLAYER_SPEED": 2.0, + } + + +class SlowMonstersMod(JaxAtariInternalModPlugin): + """Decrease the movement speed of all monsters, including the hallway chaser.""" + + constants_overrides = { + "MONSTER_SPEEDS": jnp.array([0.5, 0.75, 1.0, 1.25], dtype=jnp.float32), + "CHASER_SPEED": 0.2, + } + + +class WealthyVentureMod(JaxAtariInternalModPlugin): + """Significantly increase the points awarded for collecting treasures.""" + + constants_overrides = { + "CHEST_SCORE": 1000, + } + + +class PatientChaserMod(JaxAtariInternalModPlugin): + """Increase the time before the hallway chaser appears in a room.""" + + constants_overrides = { + "CHASER_SPAWN_FRAMES": 10000, + } + + +class FastArrowsMod(JaxAtariInternalModPlugin): + """Increase the speed of Winky's arrows.""" + + constants_overrides = { + "PROJECTILE_SPEED": 4.0, + } + + +class LongRangeArrowsMod(JaxAtariInternalModPlugin): + """Increase the distance arrows travel before disappearing.""" + + constants_overrides = { + "PROJECTILE_LIFETIME_FRAMES": 60, + } + + +class GodModeMod(JaxAtariInternalModPlugin): + """Winky is immune to collisions with hazards.""" + + def _check_player_hazard_collision(self, player_state, monster_state, chaser_state, laser_state, current_level, world_level): + return jnp.array(False) diff --git a/src/jaxatari/games/mods/venture_mods.py b/src/jaxatari/games/mods/venture_mods.py new file mode 100644 index 000000000..668acadfa --- /dev/null +++ b/src/jaxatari/games/mods/venture_mods.py @@ -0,0 +1,37 @@ +from jaxatari.modification import JaxAtariModController +from jaxatari.games.mods.venture.venture_mod_plugins import ( + FastWinkyMod, + SlowMonstersMod, + WealthyVentureMod, + PatientChaserMod, + FastArrowsMod, + LongRangeArrowsMod, + GodModeMod, +) + + +class VentureEnvMod(JaxAtariModController): + """Game-specific Mod Controller for Venture.""" + + REGISTRY = { + "fast_winky": FastWinkyMod, + "slow_monsters": SlowMonstersMod, + "wealthy_venture": WealthyVentureMod, + "patient_chaser": PatientChaserMod, + "fast_arrows": FastArrowsMod, + "long_range_arrows": LongRangeArrowsMod, + "god_mode": GodModeMod, + } + + def __init__( + self, + env, + mods_config: list = [], + allow_conflicts: bool = False, + ): + super().__init__( + env=env, + mods_config=mods_config, + allow_conflicts=allow_conflicts, + registry=self.REGISTRY, + ) From 9248f44574e4631e6d9cd635a5c9b8a0db7b10b5 Mon Sep 17 00:00:00 2001 From: Quentin Delfosse Date: Sat, 2 May 2026 16:39:25 +0200 Subject: [PATCH 03/28] Script list mods --- scripts/list_mods.py | 55 ++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 55 insertions(+) create mode 100644 scripts/list_mods.py diff --git a/scripts/list_mods.py b/scripts/list_mods.py new file mode 100644 index 000000000..d4dce1781 --- /dev/null +++ b/scripts/list_mods.py @@ -0,0 +1,55 @@ +import importlib +import sys +import os +import warnings +import textwrap + +# Suppress noisy warnings +warnings.filterwarnings("ignore", category=UserWarning) +os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "hide" + +# Add src to path so we can import jaxatari if it's not installed +sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "src"))) + +from jaxatari.core import GAME_MODULES, MOD_MODULES + +def _load_from_string(path: str): + """Dynamically import an attribute from a module path string.""" + module_path, attr_name = path.rsplit(".", 1) + module = importlib.import_module(module_path) + return getattr(module, attr_name) + +def list_mods(): + all_games = sorted(GAME_MODULES.keys()) + + print(f"{'GAME':<20} | MODS") + print("-" * 80) + + for game in all_games: + mods = [] + if game in MOD_MODULES: + try: + controller_class = _load_from_string(MOD_MODULES[game]) + if hasattr(controller_class, 'REGISTRY'): + # Get all keys from the registry + mods = sorted(controller_class.REGISTRY.keys()) + except Exception as e: + mods = [f"Error loading mods: {e}"] + + if mods: + mod_str = ", ".join(mods) + else: + mod_str = "(no mods)" + + # Wrap the mod string for better readability + wrapped_mods = textwrap.wrap(mod_str, width=57) + + if not wrapped_mods: + print(f"{game:<20} | (no mods)") + else: + print(f"{game:<20} | {wrapped_mods[0]}") + for line in wrapped_mods[1:]: + print(f"{' ':<20} | {line}") + +if __name__ == "__main__": + list_mods() From a9f8111e0e81681b633f967c85b8ccade23a655d Mon Sep 17 00:00:00 2001 From: Quentin Delfosse Date: Sat, 2 May 2026 18:50:08 +0200 Subject: [PATCH 04/28] SpaceInvaders mods --- src/jaxatari/core.py | 3 +- .../spaceinvaders_mod_plugins.py | 108 ++++++++++++++++++ src/jaxatari/games/mods/spaceinvaders_mods.py | 37 ++++++ 3 files changed, 147 insertions(+), 1 deletion(-) create mode 100644 src/jaxatari/games/mods/spaceinvaders/spaceinvaders_mod_plugins.py create mode 100644 src/jaxatari/games/mods/spaceinvaders_mods.py diff --git a/src/jaxatari/core.py b/src/jaxatari/core.py index 70a533848..75ff6562a 100644 --- a/src/jaxatari/core.py +++ b/src/jaxatari/core.py @@ -101,7 +101,8 @@ def _warn_deprecated_obs_to_flat_array(env: JaxEnvironment) -> None: "qbert": "jaxatari.games.mods.qbert_mods.QbertEnvMod", "mspacman": "jaxatari.games.mods.mspacman_mods.MsPacmanEnvMod", "beamrider": "jaxatari.games.mods.beamrider_mods.BeamRiderEnvMod", - "venture": "jaxatari.games.mods.venture_mods.VentureEnvMod" + "venture": "jaxatari.games.mods.venture_mods.VentureEnvMod", + "spaceinvaders": "jaxatari.games.mods.spaceinvaders_mods.SpaceInvadersEnvMod" } diff --git a/src/jaxatari/games/mods/spaceinvaders/spaceinvaders_mod_plugins.py b/src/jaxatari/games/mods/spaceinvaders/spaceinvaders_mod_plugins.py new file mode 100644 index 000000000..45f2c6f99 --- /dev/null +++ b/src/jaxatari/games/mods/spaceinvaders/spaceinvaders_mod_plugins.py @@ -0,0 +1,108 @@ +import jax +import jax.numpy as jnp +from functools import partial +from jaxatari.games.jax_spaceinvaders import SpaceInvadersState +from jaxatari.modification import JaxAtariInternalModPlugin, JaxAtariPostStepModPlugin + +# --- Shield Modifications --- + +class DisableShieldLeftMod(JaxAtariPostStepModPlugin): + """ + Erases the left bunker from the screen by clearing its specific memory data. + """ + @partial(jax.jit, static_argnums=(0,)) + def run(self, prev_state, new_state): + return new_state.replace( + barricade_health=new_state.barricade_health.at[0].set(0) + ) + + @partial(jax.jit, static_argnums=(0,)) + def after_reset(self, obs, state): + state = state.replace( + barricade_health=state.barricade_health.at[0].set(0) + ) + return self._env._get_observation(state), state + +class DisableShieldMiddleMod(JaxAtariPostStepModPlugin): + """ + Erases the middle bunker from the screen by clearing its specific memory data. + """ + @partial(jax.jit, static_argnums=(0,)) + def run(self, prev_state, new_state): + return new_state.replace( + barricade_health=new_state.barricade_health.at[1].set(0) + ) + + @partial(jax.jit, static_argnums=(0,)) + def after_reset(self, obs, state): + state = state.replace( + barricade_health=state.barricade_health.at[1].set(0) + ) + return self._env._get_observation(state), state + +class DisableShieldRightMod(JaxAtariPostStepModPlugin): + """ + Erases the right bunker from the screen by clearing its specific memory data. + """ + @partial(jax.jit, static_argnums=(0,)) + def run(self, prev_state, new_state): + return new_state.replace( + barricade_health=new_state.barricade_health.at[2].set(0) + ) + + @partial(jax.jit, static_argnums=(0,)) + def after_reset(self, obs, state): + state = state.replace( + barricade_health=state.barricade_health.at[2].set(0) + ) + return self._env._get_observation(state), state + +class ShiftShieldsMod(JaxAtariInternalModPlugin): + """ + Teleports all bunkers to new horizontal positions. + """ + constants_overrides = { + # Shifting all bunkers 10 pixels to the right from [41, 73, 105] to [51, 83, 115] + "BARRICADE_POS": (jnp.array([51, 83, 115], dtype=jnp.int32), 157) # 210 - 53 = 157 + } + +# --- Weapon & Gameplay Modifications --- + +class ControllableMissileMod(JaxAtariPostStepModPlugin): + """ + Forces the fired missile's horizontal position to match the player's tank, + allowing you to "steer" shots after firing. + """ + @partial(jax.jit, static_argnums=(0,)) + def run(self, prev_state, new_state): + return new_state.replace( + bullet_x=jnp.where( + new_state.bullet_active, + new_state.player_x - (self._env.consts.PLAYER_SIZE[0] // 2), + new_state.bullet_x + ) + ) + +class NoDangerMod(JaxAtariPostStepModPlugin): + """ + Removes all player shields and neutralizes incoming enemy projectiles. + """ + @partial(jax.jit, static_argnums=(0,)) + def run(self, prev_state, new_state): + # Remove all shields + new_barricade_health = new_state.barricade_health.at[:].set(0) + # Neutralize enemy projectiles (deactivate them) + new_enemy_bullets_active = jnp.zeros_like(new_state.enemy_bullets_active, dtype=jnp.bool_) + + return new_state.replace( + barricade_health=new_barricade_health, + enemy_bullets_active=new_enemy_bullets_active + ) + + @partial(jax.jit, static_argnums=(0,)) + def after_reset(self, obs, state): + state = state.replace( + barricade_health=state.barricade_health.at[:].set(0), + enemy_bullets_active=jnp.zeros_like(state.enemy_bullets_active, dtype=jnp.bool_) + ) + return self._env._get_observation(state), state diff --git a/src/jaxatari/games/mods/spaceinvaders_mods.py b/src/jaxatari/games/mods/spaceinvaders_mods.py new file mode 100644 index 000000000..50715db01 --- /dev/null +++ b/src/jaxatari/games/mods/spaceinvaders_mods.py @@ -0,0 +1,37 @@ +import os +from jaxatari.modification import JaxAtariModController +from jaxatari.games.mods.spaceinvaders.spaceinvaders_mod_plugins import ( + DisableShieldLeftMod, + DisableShieldMiddleMod, + DisableShieldRightMod, + ShiftShieldsMod, + ControllableMissileMod, + NoDangerMod +) + +class SpaceInvadersEnvMod(JaxAtariModController): + """ + Game-specific Mod Controller for SpaceInvaders. + """ + + REGISTRY = { + "disable_shield_left": DisableShieldLeftMod, + "disable_shield_middle": DisableShieldMiddleMod, + "disable_shield_right": DisableShieldRightMod, + "shift_shields": ShiftShieldsMod, + "controllable_missile": ControllableMissileMod, + "no_danger": NoDangerMod, + } + + def __init__(self, + env, + mods_config: list = [], + allow_conflicts: bool = False + ): + + super().__init__( + env=env, + mods_config=mods_config, + allow_conflicts=allow_conflicts, + registry=self.REGISTRY + ) From 30a938982ca78280d885add813f880bccfaaecd3 Mon Sep 17 00:00:00 2001 From: Quentin Delfosse Date: Sat, 2 May 2026 18:58:54 +0200 Subject: [PATCH 05/28] Skiing mods --- src/jaxatari/core.py | 3 +- src/jaxatari/games/jax_skiing.py | 4 +- .../games/mods/skiing/skiing_mod_plugins.py | 194 +++++++++++++++++- src/jaxatari/games/mods/skiing_mods.py | 11 +- src/jaxatari/modification.py | 4 +- 5 files changed, 209 insertions(+), 7 deletions(-) diff --git a/src/jaxatari/core.py b/src/jaxatari/core.py index 75ff6562a..75f7c2ec5 100644 --- a/src/jaxatari/core.py +++ b/src/jaxatari/core.py @@ -102,7 +102,8 @@ def _warn_deprecated_obs_to_flat_array(env: JaxEnvironment) -> None: "mspacman": "jaxatari.games.mods.mspacman_mods.MsPacmanEnvMod", "beamrider": "jaxatari.games.mods.beamrider_mods.BeamRiderEnvMod", "venture": "jaxatari.games.mods.venture_mods.VentureEnvMod", - "spaceinvaders": "jaxatari.games.mods.spaceinvaders_mods.SpaceInvadersEnvMod" + "spaceinvaders": "jaxatari.games.mods.spaceinvaders_mods.SpaceInvadersEnvMod", + "skiing": "jaxatari.games.mods.skiing_mods.SkiingEnvMod" } diff --git a/src/jaxatari/games/jax_skiing.py b/src/jaxatari/games/jax_skiing.py index 78a99247b..26a9c78c1 100644 --- a/src/jaxatari/games/jax_skiing.py +++ b/src/jaxatari/games/jax_skiing.py @@ -314,8 +314,8 @@ def reset(self, key: Optional[jax.random.PRNGKey] = jax.random.PRNGKey(1701)) -> row_spacing = jnp.float32(31.0) base_y = jnp.float32(60.0) - # Flags: r = 3, 7 in the repeating sequence - r_flags = jnp.array([3, 7], dtype=jnp.float32) + # Flags: patterned rows + r_flags = jnp.arange(c.max_num_flags, dtype=jnp.float32) * 4.0 + 3.0 flags_y = base_y + r_flags * row_spacing flags_x = self._get_initial_flags_x() diff --git a/src/jaxatari/games/mods/skiing/skiing_mod_plugins.py b/src/jaxatari/games/mods/skiing/skiing_mod_plugins.py index cf9b3eaa5..93b580adf 100644 --- a/src/jaxatari/games/mods/skiing/skiing_mod_plugins.py +++ b/src/jaxatari/games/mods/skiing/skiing_mod_plugins.py @@ -2,7 +2,9 @@ import jax.numpy as jnp import chex from functools import partial -from jaxatari.modification import JaxAtariInternalModPlugin +from typing import Optional, Tuple +from jaxatari.modification import JaxAtariInternalModPlugin, JaxAtariPostStepModPlugin +from jaxatari.games.jax_skiing import _enforce_min_sep_x, SkiingState, SkiingObservation class MoreTreesMod(JaxAtariInternalModPlugin): """ @@ -99,3 +101,193 @@ def _apply_tree_separation_initial(self, i: chex.Array, x0: chex.Array, tx: chex @partial(jax.jit, static_argnums=(0,)) def _apply_tree_separation_respawn(self, i: chex.Array, x_tree: chex.Array, taken_from_trees: chex.Array, taken_from_moguls: chex.Array, min_sep_tree_tree: chex.Array, min_sep_tree_mogul: chex.Array, xmin_t: chex.Array, xmax_t: chex.Array) -> chex.Array: return x_tree + + +class InvertFlagsMod(JaxAtariPostStepModPlugin): + """ + Flips the orientation or state of the slalom flags on the course. + """ + def run(self, prev_state, new_state): + c = self._env.consts + # Recalculate gate crossing with inverted logic + left_x = prev_state.flags[:, 0] + right_x = left_x + c.flag_distance + new_x = new_state.skier_x + + # Inverted: you MUST go OUTSIDE the flags + eligible = jnp.logical_or(new_x <= left_x, new_x >= right_x) + + crossed = jnp.logical_and(prev_state.flags[:, 1] > c.skier_y, + new_state.flags[:, 1] <= c.skier_y) + + gate_pass = jnp.logical_and(eligible, jnp.logical_and(crossed, jnp.logical_not(prev_state.flags_passed))) + + # Undo original scoring from step + orig_eligible = jnp.logical_and(new_x > left_x, new_x < right_x) + orig_gate_pass = jnp.logical_and(orig_eligible, jnp.logical_and(crossed, jnp.logical_not(prev_state.flags_passed))) + + corrected_score = new_state.successful_gates + jnp.sum(orig_gate_pass) - jnp.sum(gate_pass) + new_flags_passed = jnp.logical_or(prev_state.flags_passed, gate_pass) + + return new_state.replace(successful_gates=corrected_score, flags_passed=new_flags_passed) + + +class MovingFlagsMod(JaxAtariPostStepModPlugin): + """ + Causes the flags to slide horizontally across the screen while you ski. + """ + def run(self, prev_state, new_state): + c = self._env.consts + # Move flags sinusoidally + shift = jnp.sin(new_state.step_count.astype(jnp.float32) * 0.05) * 1.5 + new_flags_x = new_state.flags[:, 0] + shift + + # Keep them within bounds + min_fx = c.border_left + max_fx = c.screen_width - c.border_right - c.flag_distance + new_flags_x = jnp.clip(new_flags_x, min_fx, max_fx) + + new_flags = new_state.flags.at[:, 0].set(new_flags_x) + return new_state.replace(flags=new_flags) + + +class RandomFlagsMod(JaxAtariInternalModPlugin, JaxAtariPostStepModPlugin): + """ + Randomizes the horizontal placement of every flag. + """ + def after_reset(self, obs, state): + c = self._env.consts + key, subkey = jax.random.split(state.key) + min_fx = c.border_left + 20 + max_fx = c.screen_width - c.border_right - c.flag_distance - 20 + new_x = jax.random.uniform(subkey, (c.max_num_flags,), minval=min_fx, maxval=max_fx) + new_flags = state.flags.at[:, 0].set(new_x) + return obs, state.replace(flags=new_flags, key=key) + + def run(self, prev_state, new_state): + c = self._env.consts + # Detect if flags respawned + respawned = new_state.flags[:, 1] > prev_state.flags[:, 1] + 100 + key, subkey = jax.random.split(new_state.key) + min_fx = c.border_left + 20 + max_fx = c.screen_width - c.border_right - c.flag_distance - 20 + rand_x = jax.random.uniform(subkey, (c.max_num_flags,), minval=min_fx, maxval=max_fx) + + new_x = jnp.where(respawned, rand_x, new_state.flags[:, 0]) + new_flags = new_state.flags.at[:, 0].set(new_x) + return new_state.replace(flags=new_flags, key=key) + + +class FlagFlurryMod(JaxAtariInternalModPlugin): + """ + Dramatically increases the number of flags on the mountain. + """ + constants_overrides = { + "max_num_flags": 8, + } + + @partial(jax.jit, static_argnums=(0,)) + def reset(self, key: Optional[jax.random.PRNGKey] = jax.random.PRNGKey(1701)) -> Tuple[SkiingObservation, SkiingState]: + c = self._env.consts + _, new_key = jax.random.split(key, 2) + + row_spacing = jnp.float32(20.0) # More frequent + base_y = jnp.float32(60.0) + + # Flags: more rows + r_flags = jnp.arange(c.max_num_flags, dtype=jnp.float32) * 2.0 + 3.0 + flags_y = base_y + r_flags * row_spacing + + flags_x = self._env._get_initial_flags_x() + flags = jnp.stack([flags_x, flags_y], axis=1) + + # Trees + trees_x = self._env._get_initial_trees_x() + trees_per_row = jnp.maximum(1, c.max_num_trees // 4) + i_t = jnp.arange(c.max_num_trees, dtype=jnp.int32) + row_idx_t = i_t // trees_per_row + base_offsets_t = jnp.array([0, 1, 4, 5], dtype=jnp.float32) + r_trees = (row_idx_t // 4) * 8.0 + jnp.take(base_offsets_t, row_idx_t % 4) + trees_y = base_y + r_trees * jnp.float32(31.0) # Original row spacing for trees + stagger_t = ((i_t * 7) % 15).astype(jnp.float32) - 7.0 + trees_y = trees_y + stagger_t + + min_sep_tree = 0.5*(jnp.float32(c.tree_width)+jnp.float32(c.tree_width)) + jnp.float32(c.sep_margin_tree_tree) + xmin = jnp.float32(c.border_left) + xmax = jnp.float32(c.screen_width - c.border_right) + + def adj_tree_i(i, tx): + x0 = tx[i] + x_adj = _enforce_min_sep_x(x0, tx, min_sep_tree, xmin, xmax, n_valid=jnp.array(i, dtype=jnp.int32)) + return tx.at[i].set(self._env._enforce_tree_gap(x_adj)) + + trees_x = jax.lax.fori_loop(0, c.max_num_trees, adj_tree_i, trees_x) + trees_type = jnp.arange(c.max_num_trees, dtype=jnp.float32) % 4.0 + trees = jnp.stack([trees_x, trees_y, trees_type], axis=1) + + # Moguls + min_rx = jnp.int32(c.border_left + 50) + max_rx = jnp.int32(c.screen_width - c.border_right - 50) + span_rx = max_rx - min_rx + 1 + moguls_x = (min_rx + ((jnp.arange(c.max_num_moguls, dtype=jnp.int32) * 19) % span_rx)).astype(jnp.float32) + moguls_per_row = jnp.maximum(1, c.max_num_moguls // 2) + i_r = jnp.arange(c.max_num_moguls, dtype=jnp.int32) + row_idx_r = i_r // moguls_per_row + base_offsets_r = jnp.array([2, 6], dtype=jnp.float32) + r_moguls = (row_idx_r // 2) * 8.0 + jnp.take(base_offsets_r, row_idx_r % 2) + moguls_y = base_y + r_moguls * jnp.float32(31.0) + stagger_r = ((i_r * 11) % 15).astype(jnp.float32) - 7.0 + moguls_y = moguls_y + stagger_r + + min_sep_mogul_tree = 0.5*(jnp.float32(c.mogul_width)+jnp.float32(c.tree_width)) + jnp.float32(c.sep_margin_tree_mogul) + min_sep_mogul_mogul = 0.5*(jnp.float32(c.mogul_width)+jnp.float32(c.mogul_width)) + jnp.float32(c.sep_margin_mogul_mogul) + xmin_r = jnp.float32(c.border_left + 50) + xmax_r = jnp.float32(c.screen_width - c.border_right - 50) + tree_xs_fixed = trees[:, 0] + + def adj_mogul_i(i, rx): + x0 = rx[i] + x1 = _enforce_min_sep_x(x0, tree_xs_fixed, min_sep_mogul_tree, xmin_r, xmax_r, n_valid=jnp.array(tree_xs_fixed.shape[0], dtype=jnp.int32)) + x2 = _enforce_min_sep_x(x1, rx, min_sep_mogul_mogul, xmin_r, xmax_r, n_valid=jnp.array(i, dtype=jnp.int32)) + return rx.at[i].set(x2) + + moguls_x = jax.lax.fori_loop(0, c.max_num_moguls, adj_mogul_i, moguls_x) + moguls = jnp.stack([moguls_x, moguls_y], axis=1) + + state = SkiingState( + skier_x=jnp.array(76.0), + skier_pos=jnp.array(4, dtype=jnp.int32), + skier_fell=jnp.array(0, dtype=jnp.int32), + skier_x_speed=jnp.array(0.0), + skier_y_speed=jnp.array(0.0), + flags=flags, + trees=trees, + moguls=moguls, + successful_gates=jnp.array(20, dtype=jnp.int32), + step_count=jnp.array(0, dtype=jnp.int32), + direction_change_counter=jnp.array(0, dtype=jnp.int32), + game_over=jnp.array(False), + key=new_key, + collision_type=jnp.array(0, dtype=jnp.int32), + flags_passed=jnp.zeros(c.max_num_flags, dtype=bool), + collision_cooldown=jnp.array(0, dtype=jnp.int32), + skier_just_respawned=jnp.array(False, dtype=jnp.bool_), + jump_timer=jnp.array(0, dtype=jnp.int32), + is_jumping=jnp.array(False, dtype=jnp.bool_), + gates_seen=jnp.array(0, dtype=jnp.int32), + ) + obs = self._env._get_observation(state) + return obs, state + + +class MogulsToTreesMod(JaxAtariInternalModPlugin): + """ + Transforms all the small snow bumps (moguls) into solid trees. + """ + constants_overrides = { + "moguls_collidable": True, + "mogul_height": 30, + } + asset_overrides = { + "mogul": {'name': 'mogul', 'type': 'single', 'file': 'tree_0.npy'} + } diff --git a/src/jaxatari/games/mods/skiing_mods.py b/src/jaxatari/games/mods/skiing_mods.py index c58b6b869..95f2812e7 100644 --- a/src/jaxatari/games/mods/skiing_mods.py +++ b/src/jaxatari/games/mods/skiing_mods.py @@ -1,6 +1,10 @@ import os from jaxatari.modification import JaxAtariModController -from jaxatari.games.mods.skiing.skiing_mod_plugins import MoreTreesMod, MoreMogulsMod, DangerousMogulsMod, JumpToBreakMod, SpeedBurstMod, TreesEverywhereMod, HallOfFameMod +from jaxatari.games.mods.skiing.skiing_mod_plugins import ( + MoreTreesMod, MoreMogulsMod, DangerousMogulsMod, JumpToBreakMod, + SpeedBurstMod, TreesEverywhereMod, HallOfFameMod, + InvertFlagsMod, MovingFlagsMod, RandomFlagsMod, FlagFlurryMod, MogulsToTreesMod +) class SkiingEnvMod(JaxAtariModController): """ @@ -16,6 +20,11 @@ class SkiingEnvMod(JaxAtariModController): "jump_to_break": JumpToBreakMod, "speed_burst": SpeedBurstMod, "hall_of_fame": HallOfFameMod, + "invert_flags": InvertFlagsMod, + "moving_flags": MovingFlagsMod, + "random_flags": RandomFlagsMod, + "flag_flurry": FlagFlurryMod, + "moguls_to_trees": MogulsToTreesMod, "off_piste": ["_more_trees", "_trees_everywhere", "_more_moguls", "_dangerous_moguls"], } diff --git a/src/jaxatari/modification.py b/src/jaxatari/modification.py index 278238efd..97120ab44 100644 --- a/src/jaxatari/modification.py +++ b/src/jaxatari/modification.py @@ -629,7 +629,7 @@ def __init__(self, # Attribute overrides will be collected and applied later via helper function # Build patch map for functional conflicts for fn_name, _ in inspect.getmembers(plugin_instance, predicate=inspect.ismethod): - if not fn_name.startswith("__"): + if not fn_name.startswith("__") and fn_name not in ["run", "after_reset"]: if not (hasattr(self._env, fn_name) or (hasattr(self._env, 'renderer') and hasattr(self._env.renderer, fn_name))): raise AttributeError( f"Mod '{mod_key}' tries to patch '{fn_name}', but neither env nor renderer define it." @@ -696,7 +696,7 @@ def __init__(self, # Apply Function Patches (to env OR renderer) for fn_name, fn_logic in inspect.getmembers(plugin, predicate=inspect.ismethod): - if not fn_name.startswith("__"): + if not fn_name.startswith("__") and fn_name not in ["run", "after_reset"]: # Use the bound method directly; jit(static_argnums=(0,)) expects the instance as arg 0 env_has_attr = hasattr(self._env, fn_name) From 7f2cf79da2c4635d6e17431da0c8d7dd64c8c454 Mon Sep 17 00:00:00 2001 From: Quentin Delfosse Date: Sun, 3 May 2026 13:26:16 +0200 Subject: [PATCH 06/28] Mods for Skiing and Freeway --- src/jaxatari/games/jax_skiing.py | 30 ++++++--- .../games/mods/freeway/freeway_mod_plugins.py | 57 ++++++++++++++++++ src/jaxatari/games/mods/freeway_mods.py | 3 +- .../games/mods/skiing/skiing_mod_plugins.py | 54 +++++++++++++++++ .../mods/skiing/sprites/blue_skier_fallen.npy | Bin 0 -> 848 bytes .../mods/skiing/sprites/blue_skiier_0.npy | Bin 0 -> 640 bytes .../mods/skiing/sprites/blue_skiier_1.npy | Bin 0 -> 608 bytes .../mods/skiing/sprites/blue_skiier_2.npy | Bin 0 -> 632 bytes .../mods/skiing/sprites/blue_skiier_3.npy | Bin 0 -> 632 bytes .../mods/skiing/sprites/blue_skiier_4.npy | Bin 0 -> 632 bytes .../mods/skiing/sprites/blue_skiier_5.npy | Bin 0 -> 632 bytes .../mods/skiing/sprites/blue_skiier_6.npy | Bin 0 -> 608 bytes .../mods/skiing/sprites/blue_skiier_7.npy | Bin 0 -> 640 bytes src/jaxatari/games/mods/skiing_mods.py | 12 +++- 14 files changed, 145 insertions(+), 11 deletions(-) create mode 100644 src/jaxatari/games/mods/skiing/sprites/blue_skier_fallen.npy create mode 100644 src/jaxatari/games/mods/skiing/sprites/blue_skiier_0.npy create mode 100644 src/jaxatari/games/mods/skiing/sprites/blue_skiier_1.npy create mode 100644 src/jaxatari/games/mods/skiing/sprites/blue_skiier_2.npy create mode 100644 src/jaxatari/games/mods/skiing/sprites/blue_skiier_3.npy create mode 100644 src/jaxatari/games/mods/skiing/sprites/blue_skiier_4.npy create mode 100644 src/jaxatari/games/mods/skiing/sprites/blue_skiier_5.npy create mode 100644 src/jaxatari/games/mods/skiing/sprites/blue_skiier_6.npy create mode 100644 src/jaxatari/games/mods/skiing/sprites/blue_skiier_7.npy diff --git a/src/jaxatari/games/jax_skiing.py b/src/jaxatari/games/jax_skiing.py index 26a9c78c1..deb533076 100644 --- a/src/jaxatari/games/jax_skiing.py +++ b/src/jaxatari/games/jax_skiing.py @@ -112,7 +112,7 @@ def upsample(bits): 'tree_2.npy', 'tree_3.npy' ]}, - {'name': 'mogul', 'type': 'single', 'file': 'stone.npy'}, + {'name': 'mogul', 'type': 'single', 'file': 'mogul.npy'}, # UI {'name': 'digits', 'type': 'procedural', 'data': procedural_digits}, @@ -127,6 +127,9 @@ class SkiingConstants(AutoDerivedConstants): USE_ORIGINAL_ALE_REWARD: bool = struct.field(pytree_node=False, default=True) BOTTOM_BORDER: int = struct.field(pytree_node=False, default=176) TOP_BORDER: int = struct.field(pytree_node=False, default=-15) + invert_flag_colors: bool = struct.field(pytree_node=False, default=False) + green_flags: bool = struct.field(pytree_node=False, default=False) + blue_skier: bool = struct.field(pytree_node=False, default=False) """Game configuration parameters""" screen_width: int = struct.field(pytree_node=False, default=160) screen_height: int = struct.field(pytree_node=False, default=210) @@ -986,12 +989,24 @@ def __init__(self, consts: SkiingConstants = None, config: render_utils.Renderer self.jr = render_utils.JaxRenderingUtils(self.config) # 2. Start from (possibly modded) asset config provided via constants - final_asset_config = list(self.consts.ASSET_CONFIG) + final_asset_config = [] + for asset in self.consts.ASSET_CONFIG: + new_asset = dict(asset) + if asset.get('name') == 'skier_group' and getattr(self.consts, "blue_skier", False): + new_asset['recolorings'] = {'blue': (0, 0, 255)} + final_asset_config.append(new_asset) # 3. Load flags (needs sprite path, so done here) flag_red_rgba = self._load_rgba_sprite("checkered_flag_red.npy") flag_blue_rgba = self._load_rgba_sprite("checkered_flag_blue.npy") + if getattr(self.consts, "invert_flag_colors", False): + flag_red_rgba, flag_blue_rgba = flag_blue_rgba, flag_red_rgba + + if getattr(self.consts, "green_flags", False): + flag_red_rgba = np.array(self.jr.perform_recoloring(jnp.array(flag_red_rgba), (50, 200, 50))) + flag_blue_rgba = np.array(self.jr.perform_recoloring(jnp.array(flag_blue_rgba), (50, 200, 50))) + # Pad them so they have the same shape for jax.lax.select max_h = max(flag_red_rgba.shape[0], flag_blue_rgba.shape[0]) max_w = max(flag_red_rgba.shape[1], flag_blue_rgba.shape[1]) @@ -1032,6 +1047,10 @@ def __init__(self, consts: SkiingConstants = None, config: render_utils.Renderer self.BACKGROUND = jnp.array(bg) # 5. Store key color/shape IDs + if 'skier_group_blue' in self.SHAPE_MASKS: + self.SHAPE_MASKS['skier_group'] = self.SHAPE_MASKS['skier_group_blue'] + self.FLIP_OFFSETS['skier_group'] = self.FLIP_OFFSETS['skier_group_blue'] + self.RED_FLAG_MASK = self.SHAPE_MASKS['flag_red'] self.BLUE_FLAG_MASK = self.SHAPE_MASKS['flag_blue'] self.RED_FLAG_OFFSET = self.FLIP_OFFSETS['flag_red'] @@ -1054,13 +1073,6 @@ def _load_rgba_sprite(self, file_name: str) -> np.ndarray: a = np.full(rgba.shape[:2] + (1,), 255, np.uint8) rgba = np.concatenate([rgba, a], axis=-1) return rgba - - def _recolor_rgba(self, sprite_rgba: np.ndarray, rgb: Tuple[int,int,int]) -> np.ndarray: - """Manually recolors an RGBA sprite. For setup only.""" - mask = (sprite_rgba[..., 3:4] > 0) - rgb_arr = np.array(rgb, dtype=np.uint8)[None, None, :] - new_rgb = np.where(mask, rgb_arr, sprite_rgba[..., :3]) - return np.concatenate([new_rgb, sprite_rgba[..., 3:4]], axis=-1) @partial(jax.jit, static_argnums=(0,)) def _format_score_digits(self, score: jnp.ndarray) -> jnp.ndarray: diff --git a/src/jaxatari/games/mods/freeway/freeway_mod_plugins.py b/src/jaxatari/games/mods/freeway/freeway_mod_plugins.py index f67f21583..28c62135b 100644 --- a/src/jaxatari/games/mods/freeway/freeway_mod_plugins.py +++ b/src/jaxatari/games/mods/freeway/freeway_mod_plugins.py @@ -138,3 +138,60 @@ def after_reset(self, obs, state: FreewayState): # Return modified observation and state modified_state = state.replace(cars=centered_cars) return obs, modified_state + + +import os +from jaxatari.rendering.jax_rendering_utils import JaxRenderingUtils, RendererConfig, get_base_sprite_dir + +# Initialize utilities +_jr = JaxRenderingUtils(RendererConfig()) +_bike_path = os.path.join(get_base_sprite_dir(), "freeway", "bike.npy") +_bike_array = _jr.loadFrame(_bike_path) + +# Define distinct color pairs for 10 lanes (Biker, Motorbike) +_color_pairs = [ + ((255, 0, 0), (0, 0, 255)), # Lane 0: Red / Blue + ((0, 255, 0), (255, 255, 0)), # Lane 1: Green / Yellow + ((255, 0, 255), (0, 255, 255)), # Lane 2: Magenta / Cyan + ((255, 128, 0), (128, 0, 255)), # Lane 3: Orange / Purple + ((255, 255, 255), (0, 0, 0)), # Lane 4: White / Black + ((0, 0, 255), (255, 0, 0)), # Lane 5: Blue / Red + ((255, 255, 0), (0, 255, 0)), # Lane 6: Yellow / Green + ((0, 255, 255), (255, 0, 255)), # Lane 7: Cyan / Magenta + ((128, 0, 255), (255, 128, 0)), # Lane 8: Purple / Orange + ((0, 0, 0), (255, 255, 255)), # Lane 9: Black / White +] + +_recolored_bikes = [] +for _biker_color, _motorbike_color in _color_pairs: + _rule = [ + {'source': (80, 184, 57), 'target': _biker_color}, + {'source': (32, 167, 32), 'target': _biker_color}, + {'source': (234, 61, 49), 'target': _motorbike_color}, + {'source': (255, 32, 32), 'target': _motorbike_color} + ] + _recolored_bikes.append(_jr.perform_recoloring(_bike_array, _rule)) + + +class BikesMod(JaxAtariInternalModPlugin): + """Replaces all cars with uniquely colored bike sprites.""" + constants_overrides = { + "ASSET_CONFIG": ( + {'name': 'background', 'type': 'background', 'file': 'background.npy'}, + { + 'name': 'player', 'type': 'group', + 'files': ['player_hit.npy', 'player_walk.npy', 'player_idle.npy'] + }, + {'name': 'car_dark_red', 'type': 'procedural', 'data': _recolored_bikes[0]}, + {'name': 'car_light_green', 'type': 'procedural', 'data': _recolored_bikes[1]}, + {'name': 'car_dark_green', 'type': 'procedural', 'data': _recolored_bikes[2]}, + {'name': 'car_light_red', 'type': 'procedural', 'data': _recolored_bikes[3]}, + {'name': 'car_blue', 'type': 'procedural', 'data': _recolored_bikes[4]}, + {'name': 'car_brown', 'type': 'procedural', 'data': _recolored_bikes[5]}, + {'name': 'car_light_blue', 'type': 'procedural', 'data': _recolored_bikes[6]}, + {'name': 'car_red', 'type': 'procedural', 'data': _recolored_bikes[7]}, + {'name': 'car_green', 'type': 'procedural', 'data': _recolored_bikes[8]}, + {'name': 'car_yellow', 'type': 'procedural', 'data': _recolored_bikes[9]}, + {'name': 'score_digits', 'type': 'digits', 'pattern': 'score_{}.npy'}, + ) + } diff --git a/src/jaxatari/games/mods/freeway_mods.py b/src/jaxatari/games/mods/freeway_mods.py index 014dfef56..bba28562f 100644 --- a/src/jaxatari/games/mods/freeway_mods.py +++ b/src/jaxatari/games/mods/freeway_mods.py @@ -1,6 +1,6 @@ import os from jaxatari.modification import JaxAtariModController -from jaxatari.games.mods.freeway.freeway_mod_plugins import StopAllCarsMod, StaticCarsMod, SlowCarsMod, BlackCarsMod, CenterCarsOnResetMod, InvertSpeed, HallOfFameMod +from jaxatari.games.mods.freeway.freeway_mod_plugins import StopAllCarsMod, StaticCarsMod, SlowCarsMod, BlackCarsMod, CenterCarsOnResetMod, InvertSpeed, HallOfFameMod, BikesMod class FreewayEnvMod(JaxAtariModController): """ @@ -17,6 +17,7 @@ class FreewayEnvMod(JaxAtariModController): "center_cars_on_reset": CenterCarsOnResetMod, "hall_of_fame": ["_hall_of_fame_start", "static_cars"], "_hall_of_fame_start": HallOfFameMod, + "bikes": BikesMod, } _mod_sprite_dir = os.path.join(os.path.dirname(__file__), "freeway", "sprites") diff --git a/src/jaxatari/games/mods/skiing/skiing_mod_plugins.py b/src/jaxatari/games/mods/skiing/skiing_mod_plugins.py index 93b580adf..864080960 100644 --- a/src/jaxatari/games/mods/skiing/skiing_mod_plugins.py +++ b/src/jaxatari/games/mods/skiing/skiing_mod_plugins.py @@ -132,6 +132,15 @@ def run(self, prev_state, new_state): return new_state.replace(successful_gates=corrected_score, flags_passed=new_flags_passed) +class InvertFlagColorsMod(JaxAtariInternalModPlugin): + """ + Inverts the colors of the slalom flags (all flags become red, the last one becomes blue). + """ + constants_overrides = { + "invert_flag_colors": True, + } + + class MovingFlagsMod(JaxAtariPostStepModPlugin): """ Causes the flags to slide horizontally across the screen while you ski. @@ -291,3 +300,48 @@ class MogulsToTreesMod(JaxAtariInternalModPlugin): asset_overrides = { "mogul": {'name': 'mogul', 'type': 'single', 'file': 'tree_0.npy'} } + + +class ClassicTreesMod(JaxAtariInternalModPlugin): + """ + Replaces the default tree sprites with classic versions. + """ + asset_overrides = { + "tree_group": { + 'name': 'tree_group', + 'type': 'group', + 'files': [ + 'classic_tree_0.npy', + 'classic_tree_1.npy', + 'classic_tree_2.npy', + 'classic_tree_3.npy' + ] + } + } + + +class ThinMogulsMod(JaxAtariInternalModPlugin): + """ + Replaces the standard moguls with thinner ones. + """ + asset_overrides = { + "mogul": {'name': 'mogul', 'type': 'single', 'file': 'thin_mogul.npy'} + } + + +class BlueSkiierMod(JaxAtariInternalModPlugin): + """ + Replaces the skier sprites with blue versions. + """ + constants_overrides = { + "blue_skier": True + } + + +class GreenFlagsMod(JaxAtariInternalModPlugin): + """ + Recolors all flags to green. + """ + constants_overrides = { + "green_flags": True + } diff --git a/src/jaxatari/games/mods/skiing/sprites/blue_skier_fallen.npy b/src/jaxatari/games/mods/skiing/sprites/blue_skier_fallen.npy new file mode 100644 index 0000000000000000000000000000000000000000..f882835310cebd1b7f8a68ab77c9d882bbd2ecc0 GIT binary patch literal 848 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1JQ);NLqoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$-7CH)s2098RnmP)#3giGTh9Lmd!xtFuSptfvg{09X9naF=Evc MqYmZ literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/mods/skiing/sprites/blue_skiier_0.npy b/src/jaxatari/games/mods/skiing/sprites/blue_skiier_0.npy new file mode 100644 index 0000000000000000000000000000000000000000..3c843e96a5cce5b1a5aea5add8480e4ea814d8f5 GIT binary patch literal 640 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1JQ);NLqoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$-W;zNMItnJ5ItsN4WC1SqxNrYKm;nbMR-PC$aH)aW0i)5K0uzVnhtbID zU}|7u=xSi{=;FxgVQOGv$TUnJj7F9Nv0-{~$rDlo(+^VzqCo&Aj*W(>KqX=7kVRo) G*k}MrD50(Z literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/mods/skiing/sprites/blue_skiier_1.npy b/src/jaxatari/games/mods/skiing/sprites/blue_skiier_1.npy new file mode 100644 index 0000000000000000000000000000000000000000..2c312bd3077b1d64224b43aa5894e2e0164104c9 GIT binary patch literal 608 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1JQ);NLqoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$-raB51ItnJ5ItsN4WC1SqxNrYKm;ncX$-`)D%3yLZ8eI)c97KcUU^KEi zd}?6o(bXX9hnWjggDyv`JWL%f{lv<{+zc}Z<}MhGO&n%6ObtFWVEWPZAd7?80JUJ8 Axc~qF literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/mods/skiing/sprites/blue_skiier_2.npy b/src/jaxatari/games/mods/skiing/sprites/blue_skiier_2.npy new file mode 100644 index 0000000000000000000000000000000000000000..08bcbcefad379faa2ee9cf02274aa7aad9dcb9a0 GIT binary patch literal 632 zcmbV_u?oU47=-KUQ)K%UvJ?cx2XM2vIJhXmMjXUSqAu!F_^z(npK@@7X7Pvg{(N_r zzN__ewHXxm;vuz<`yg{8&s!;Lf`y)L!|*ouY4fJi3c-@1t^N&D5I7*qWF5 zYlgD@eDxGP0}sj(FxH%}TyyZ#4m{>y+`&7@!GnUQUG!MdVV3LRUCv^M2W6h_xI_7W TxjwS0I^5|<_vJow-TOK}KfSII literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/mods/skiing/sprites/blue_skiier_3.npy b/src/jaxatari/games/mods/skiing/sprites/blue_skiier_3.npy new file mode 100644 index 0000000000000000000000000000000000000000..547de3d341a1e57a60b5cd3a5bff76881bc8352d GIT binary patch literal 632 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1JQ);NLqoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$-7CH*%ItnJ5ItsN4WC1P)C{U04_8){{d}JDz984TWW77kZ1JNLPWIi@F zx*B9YG8^4IbUB#WAaM|Y@nJMX6hdM%3nq`O9;Oam4p|LM4n)Jm(e=UPkZBMb24L!8 OG)x>s!^B`TxpDw#^SAi` literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/mods/skiing/sprites/blue_skiier_4.npy b/src/jaxatari/games/mods/skiing/sprites/blue_skiier_4.npy new file mode 100644 index 0000000000000000000000000000000000000000..3718557af6412c9ced4f2a4f3f64530282ea0b22 GIT binary patch literal 632 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1JQ);NLqoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$-7CH*%ItnJ5ItsN4WC1P)B%mJm?LP=33Bj4Ti5omjj7|08AXl1`9$7bUV=HV0vNdVDh-cLGm!QFdC+Z-f{rV CqPO_~ literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/mods/skiing/sprites/blue_skiier_5.npy b/src/jaxatari/games/mods/skiing/sprites/blue_skiier_5.npy new file mode 100644 index 0000000000000000000000000000000000000000..36a278ddbf0af92dcef36eb832184cb4f767fa83 GIT binary patch literal 632 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1JQ);NLqoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$-7CH*%ItnJ5ItsN4WC1P)EI>W(+kX&7mV$}FXqX&~k4%H)KztYtlS9^n zOAc8LOdrfFYWoQnFCXU%tvN} HJd6(j0S~Sb literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/mods/skiing/sprites/blue_skiier_6.npy b/src/jaxatari/games/mods/skiing/sprites/blue_skiier_6.npy new file mode 100644 index 0000000000000000000000000000000000000000..b22f2c77232d976532cec46b9a0a36f88a10d72b GIT binary patch literal 608 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1JQ);NLqoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$-raB51ItnJ5ItsN4WC1P)96&wp+kX&7mxhVMXk>Yq7>q`j1My*UAQ~nQ zV`Gy;Rs)lVnFEu9@sZ_`<%pGssYBO`EDm!kvKnN2vB|^q!_*>jtR7h%G4e1wkoCdX`1IqGLpBqp4rV4y4knIH!_>jVVSE@36URn_)Bpgx CE1|9c literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/mods/skiing_mods.py b/src/jaxatari/games/mods/skiing_mods.py index 95f2812e7..601bf8bd5 100644 --- a/src/jaxatari/games/mods/skiing_mods.py +++ b/src/jaxatari/games/mods/skiing_mods.py @@ -3,7 +3,8 @@ from jaxatari.games.mods.skiing.skiing_mod_plugins import ( MoreTreesMod, MoreMogulsMod, DangerousMogulsMod, JumpToBreakMod, SpeedBurstMod, TreesEverywhereMod, HallOfFameMod, - InvertFlagsMod, MovingFlagsMod, RandomFlagsMod, FlagFlurryMod, MogulsToTreesMod + InvertFlagsMod, InvertFlagColorsMod, MovingFlagsMod, RandomFlagsMod, FlagFlurryMod, MogulsToTreesMod, + ClassicTreesMod, ThinMogulsMod, BlueSkiierMod, GreenFlagsMod ) class SkiingEnvMod(JaxAtariModController): @@ -12,6 +13,9 @@ class SkiingEnvMod(JaxAtariModController): It inherits all logic from JaxAtariModController and defines the REGISTRY. """ + # Define the path relative to this file (mod sprites fallback) + _mod_sprite_dir = os.path.join(os.path.dirname(__file__), "skiing", "sprites") + REGISTRY = { "_more_trees": MoreTreesMod, "_trees_everywhere": TreesEverywhereMod, @@ -21,11 +25,17 @@ class SkiingEnvMod(JaxAtariModController): "speed_burst": SpeedBurstMod, "hall_of_fame": HallOfFameMod, "invert_flags": InvertFlagsMod, + "invert_flag_colors": InvertFlagColorsMod, "moving_flags": MovingFlagsMod, "random_flags": RandomFlagsMod, "flag_flurry": FlagFlurryMod, "moguls_to_trees": MogulsToTreesMod, + "classic_trees": ClassicTreesMod, + "thin_moguls": ThinMogulsMod, + "blue_skiier": BlueSkiierMod, + "green_flags": GreenFlagsMod, "off_piste": ["_more_trees", "_trees_everywhere", "_more_moguls", "_dangerous_moguls"], + "change_sprites": ["classic_trees", "thin_moguls", "blue_skiier", "green_flags"], } def __init__(self, From e931cb7888c55901a789f9a3a191fdf97ed857c9 Mon Sep 17 00:00:00 2001 From: Quentin Delfosse Date: Sun, 3 May 2026 14:32:32 +0200 Subject: [PATCH 07/28] Frostbite mods --- games_covered.md | 2 +- src/jaxatari/games/jax_frostbite.py | 132 +++++++++++------- .../games/mods/freeway/freeway_mod_plugins.py | 75 +++++++--- src/jaxatari/games/mods/freeway_mods.py | 6 +- .../mods/frostbite/frostbite_mod_plugins.py | 74 +++++++++- src/jaxatari/games/mods/frostbite_mods.py | 12 +- 6 files changed, 223 insertions(+), 78 deletions(-) diff --git a/games_covered.md b/games_covered.md index 1fc85b232..883f5112d 100644 --- a/games_covered.md +++ b/games_covered.md @@ -40,7 +40,7 @@ Total:  🥇: 1 | 🥈: 3 | 🥉: 1  |  ❌: 1 | enduro | 🥇 | 0 | | fishing_derby | 🥇 | 0 | | freeway | 🥇 | 3 | -| frostbite | 🥇 | 0 | +| frostbite | 🥇 | 11 | | gopher | ❌ | 0 | | gravitar | 🥇 | 0 | | hero | ❌ | 0 | diff --git a/src/jaxatari/games/jax_frostbite.py b/src/jaxatari/games/jax_frostbite.py index 49d1422a6..6dab7f7f3 100644 --- a/src/jaxatari/games/jax_frostbite.py +++ b/src/jaxatari/games/jax_frostbite.py @@ -96,11 +96,31 @@ class FrostbiteConstants(struct.PyTreeNode): # RGB Overrides for mods (if set, overrides the actual rendered color of the ice blocks) RGB_ICE_WHITE: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) RGB_ICE_BLUE: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + + # RGB Overrides for obstacles + RGB_FISH: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_GEESE: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_CRAB: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_CLAM: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + + # Sprite overrides + BEAR_SPRITE_0: str = struct.field(pytree_node=False, default="bear_00.npy") + BEAR_SPRITE_1: str = struct.field(pytree_node=False, default="bear_01.npy") + + # Igloo overrides + IGLOO_X_OFFSET: int = struct.field(pytree_node=False, default=0) + RGB_IGLOO: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + TARGET_IGLOO_X: int = struct.field(pytree_node=False, default=123) # Igloo constants IGLOO_X: int = struct.field(pytree_node=False, default=154) # X position of igloo (far right side of screen) IGLOO_X: int = struct.field(pytree_node=False, default=154) IGLOO_Y: int = struct.field(pytree_node=False, default=44) # Y position at top of Bailey's head when on shore + # Environment mode overrides + CONSTANT_NIGHT: bool = struct.field(pytree_node=False, default=False) + RGB_NIGHT: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + DRAW_SHORE_LINE: bool = struct.field(pytree_node=False, default=False) + # Game Constants MAX_IGLOO_INDEX: int = struct.field(pytree_node=False, default=15) # Complete igloo has 16 blocks (0-15) MAX_EATEN_FISH: int = struct.field(pytree_node=False, default=12) # Max fish that can be eaten per level @@ -1566,7 +1586,7 @@ def _update_bailey(self, state: FrostbiteState, action: int): x_max = self.consts.SHORE_X_MAX # Automatic movement toward igloo during entry sequence - target_igloo_x = 123 # Fixed igloo X position + target_igloo_x = self.consts.TARGET_IGLOO_X # Mod-aware igloo X position auto_dx = jnp.where( is_entering_igloo, jnp.sign(target_igloo_x - state.bailey_x) * jnp.minimum(2, jnp.abs(target_igloo_x - state.bailey_x)), @@ -1611,7 +1631,7 @@ def _update_bailey(self, state: FrostbiteState, action: int): can_start_jump = state.bailey_jumping_idx == 0 # Not already jumping # Special case: automatic jump when entering igloo - at_igloo_x = jnp.abs(state.bailey_x - 123) <= 1 + at_igloo_x = jnp.abs(state.bailey_x - target_igloo_x) <= 1 should_jump_for_igloo = is_entering_igloo & at_igloo_x # Determine jump intent using the refactored method @@ -2854,7 +2874,7 @@ def _check_igloo_entry(self, state: FrostbiteState, action: int): not_entering = state.igloo_entry_status == 0 # Door collision box (igloo door is at specific X position) - door_x = 122 + door_x = self.consts.TARGET_IGLOO_X - 1 bailey_right_edge = state.bailey_x + 16 near_door = (bailey_right_edge >= door_x) & (state.bailey_x <= door_x + 8) @@ -2998,8 +3018,8 @@ def _load_and_prepare_assets(self): crab_1 = self._load_frame_legacy("king_crab_01.npy") clam_0 = jnp.flip(self._load_frame_legacy("clam_00.npy"), axis=1) clam_1 = jnp.flip(self._load_frame_legacy("clam_01.npy"), axis=1) - bear_0 = self._load_frame_legacy("bear_00.npy") - bear_1 = self._load_frame_legacy("bear_01.npy") + bear_0 = self._load_frame_legacy(self.consts.BEAR_SPRITE_0) + bear_1 = self._load_frame_legacy(self.consts.BEAR_SPRITE_1) igloo_block = self._load_frame_legacy("igloo_block_00.npy") igloo_door = self._load_frame_legacy("igloo_door.npy") degree_symbol = self._load_frame_legacy("degree_symbol.npy") @@ -3021,50 +3041,50 @@ def _load_and_prepare_assets(self): # Apply custom RGB colors if set by mods if self.consts.RGB_ICE_WHITE is not None: r, g, b = self.consts.RGB_ICE_WHITE - ice_wide_white = jnp.where( - ice_wide_white[..., 3:4] > 0, - jnp.concatenate([ - jnp.full_like(ice_wide_white[..., 0:1], r), - jnp.full_like(ice_wide_white[..., 1:2], g), - jnp.full_like(ice_wide_white[..., 2:3], b), - ice_wide_white[..., 3:4] - ], axis=-1), - ice_wide_white - ).astype(ice_wide_white.dtype) - ice_narrow_white = jnp.where( - ice_narrow_white[..., 3:4] > 0, - jnp.concatenate([ - jnp.full_like(ice_narrow_white[..., 0:1], r), - jnp.full_like(ice_narrow_white[..., 1:2], g), - jnp.full_like(ice_narrow_white[..., 2:3], b), - ice_narrow_white[..., 3:4] - ], axis=-1), - ice_narrow_white - ).astype(ice_narrow_white.dtype) + ice_wide_white = self._apply_custom_tint(ice_wide_white, r, g, b) + ice_narrow_white = self._apply_custom_tint(ice_narrow_white, r, g, b) if self.consts.RGB_ICE_BLUE is not None: r, g, b = self.consts.RGB_ICE_BLUE - ice_wide_blue = jnp.where( - ice_wide_blue[..., 3:4] > 0, - jnp.concatenate([ - jnp.full_like(ice_wide_blue[..., 0:1], r), - jnp.full_like(ice_wide_blue[..., 1:2], g), - jnp.full_like(ice_wide_blue[..., 2:3], b), - ice_wide_blue[..., 3:4] - ], axis=-1), - ice_wide_blue - ).astype(ice_wide_blue.dtype) - ice_narrow_blue = jnp.where( - ice_narrow_blue[..., 3:4] > 0, - jnp.concatenate([ - jnp.full_like(ice_narrow_blue[..., 0:1], r), - jnp.full_like(ice_narrow_blue[..., 1:2], g), - jnp.full_like(ice_narrow_blue[..., 2:3], b), - ice_narrow_blue[..., 3:4] - ], axis=-1), - ice_narrow_blue - ).astype(ice_narrow_blue.dtype) - + ice_wide_blue = self._apply_custom_tint(ice_wide_blue, r, g, b) + ice_narrow_blue = self._apply_custom_tint(ice_narrow_blue, r, g, b) + + if self.consts.RGB_GEESE is not None: + r, g, b = self.consts.RGB_GEESE + geese_0 = self._apply_custom_tint(geese_0, r, g, b) + geese_1 = self._apply_custom_tint(geese_1, r, g, b) + + if self.consts.RGB_FISH is not None: + r, g, b = self.consts.RGB_FISH + fish_0 = self._apply_custom_tint(fish_0, r, g, b) + fish_1 = self._apply_custom_tint(fish_1, r, g, b) + + if self.consts.RGB_CRAB is not None: + r, g, b = self.consts.RGB_CRAB + crab_0 = self._apply_custom_tint(crab_0, r, g, b) + crab_1 = self._apply_custom_tint(crab_1, r, g, b) + + if self.consts.RGB_CLAM is not None: + r, g, b = self.consts.RGB_CLAM + clam_0 = self._apply_custom_tint(clam_0, r, g, b) + clam_1 = self._apply_custom_tint(clam_1, r, g, b) + + if self.consts.RGB_IGLOO is not None: + r, g, b = self.consts.RGB_IGLOO + igloo_block = self._apply_custom_tint(igloo_block, r, g, b) + # The door is black (0,0,0); tinting it makes it disappear into the igloo blocks. + # We leave igloo_door as-is. + + if self.consts.RGB_NIGHT is not None: + r, g, b = self.consts.RGB_NIGHT + bg_night = self._apply_custom_tint(bg_night, r, g, b) + bg_day = self._apply_custom_tint(bg_day, r, g, b) + + if self.consts.DRAW_SHORE_LINE: + line_color = jnp.array([255, 255, 255, 255], dtype=jnp.uint8) + bg_night = bg_night.at[78, :].set(line_color) + bg_day = bg_day.at[78, :].set(line_color) + # Bear (Lightened for Night) bear_0_light = self._lighten_bear(bear_0) bear_1_light = self._lighten_bear(bear_1) @@ -3240,7 +3260,7 @@ def _build_strip(block_mask, n_blocks, spacing): import numpy as _np _CH, _CW = 24, 32 - _cx0, _cy0 = 111, 35 # canvas origin in raster coords + _cx0, _cy0 = 111 + self.consts.IGLOO_X_OFFSET, 35 # canvas origin in raster coords _bm = _np.array(self.IGLOO_BLOCK_MASK) # (8,8) palette IDs _dm = _np.array(self.IGLOO_DOOR_MASK) # (8,8) palette IDs @@ -3304,6 +3324,20 @@ def _decode_sprite_duplication(consts, code: jnp.ndarray): # --- Tinting helpers (used only in __init__) --- + @staticmethod + def _apply_custom_tint(sprite, r, g, b): + """Apply a custom RGB tint to a sprite, preserving alpha.""" + return jnp.where( + sprite[..., 3:4] > 0, + jnp.concatenate([ + jnp.full_like(sprite[..., 0:1], r), + jnp.full_like(sprite[..., 1:2], g), + jnp.full_like(sprite[..., 2:3], b), + sprite[..., 3:4] + ], axis=-1), + sprite + ).astype(sprite.dtype) + @staticmethod def _apply_ice_color(block_sprite, is_blue): """Apply color tinting to ice block sprites.""" @@ -3557,7 +3591,7 @@ def _render_polar_grizzly(self, raster, state): """Render the polar grizzly (bear) when active.""" should_render = state.polar_grizzly_active == 1 def draw_bear(r): - is_night = ((state.level - 1) // 4) % 2 == 1 + is_night = jnp.logical_or(((state.level - 1) // 4) % 2 == 1, self.consts.CONSTANT_NIGHT) bear_stack = jax.lax.select(is_night, self.BEAR_LIGHT_MASKS, self.BEAR_MASKS) @@ -3583,7 +3617,7 @@ def render(self, state: FrostbiteState) -> jnp.ndarray: # 1. Render background (day/night) raster = self.jr.create_object_raster(self.BACKGROUND).astype(self.PALETTE.dtype) - is_night = ((state.level - 1) // 4) % 2 == 1 + is_night = jnp.logical_or(((state.level - 1) // 4) % 2 == 1, self.consts.CONSTANT_NIGHT) raster = jax.lax.cond( is_night, lambda r: self.jr.render_at(r, 0, 0, self.SHAPE_MASKS['background_night']), diff --git a/src/jaxatari/games/mods/freeway/freeway_mod_plugins.py b/src/jaxatari/games/mods/freeway/freeway_mod_plugins.py index 28c62135b..10b91e09f 100644 --- a/src/jaxatari/games/mods/freeway/freeway_mod_plugins.py +++ b/src/jaxatari/games/mods/freeway/freeway_mod_plugins.py @@ -173,25 +173,62 @@ def after_reset(self, obs, state: FreewayState): _recolored_bikes.append(_jr.perform_recoloring(_bike_array, _rule)) +class FrogMod(JaxAtariInternalModPlugin): + """Replaces the player sprites with frog sprites.""" + asset_overrides = { + "player": { + 'name': 'player', 'type': 'group', + 'files': ['frog_hit.npy', 'frog_walk.npy', 'frog_idle.npy'] + } + } + + class BikesMod(JaxAtariInternalModPlugin): """Replaces all cars with uniquely colored bike sprites.""" - constants_overrides = { - "ASSET_CONFIG": ( - {'name': 'background', 'type': 'background', 'file': 'background.npy'}, - { - 'name': 'player', 'type': 'group', - 'files': ['player_hit.npy', 'player_walk.npy', 'player_idle.npy'] - }, - {'name': 'car_dark_red', 'type': 'procedural', 'data': _recolored_bikes[0]}, - {'name': 'car_light_green', 'type': 'procedural', 'data': _recolored_bikes[1]}, - {'name': 'car_dark_green', 'type': 'procedural', 'data': _recolored_bikes[2]}, - {'name': 'car_light_red', 'type': 'procedural', 'data': _recolored_bikes[3]}, - {'name': 'car_blue', 'type': 'procedural', 'data': _recolored_bikes[4]}, - {'name': 'car_brown', 'type': 'procedural', 'data': _recolored_bikes[5]}, - {'name': 'car_light_blue', 'type': 'procedural', 'data': _recolored_bikes[6]}, - {'name': 'car_red', 'type': 'procedural', 'data': _recolored_bikes[7]}, - {'name': 'car_green', 'type': 'procedural', 'data': _recolored_bikes[8]}, - {'name': 'car_yellow', 'type': 'procedural', 'data': _recolored_bikes[9]}, - {'name': 'score_digits', 'type': 'digits', 'pattern': 'score_{}.npy'}, - ) + asset_overrides = { + 'car_dark_red': {'name': 'car_dark_red', 'type': 'procedural', 'data': _recolored_bikes[0]}, + 'car_light_green': {'name': 'car_light_green', 'type': 'procedural', 'data': _recolored_bikes[1]}, + 'car_dark_green': {'name': 'car_dark_green', 'type': 'procedural', 'data': _recolored_bikes[2]}, + 'car_light_red': {'name': 'car_light_red', 'type': 'procedural', 'data': _recolored_bikes[3]}, + 'car_blue': {'name': 'car_blue', 'type': 'procedural', 'data': _recolored_bikes[4]}, + 'car_brown': {'name': 'car_brown', 'type': 'procedural', 'data': _recolored_bikes[5]}, + 'car_light_blue': {'name': 'car_light_blue', 'type': 'procedural', 'data': _recolored_bikes[6]}, + 'car_red': {'name': 'car_red', 'type': 'procedural', 'data': _recolored_bikes[7]}, + 'car_green': {'name': 'car_green', 'type': 'procedural', 'data': _recolored_bikes[8]}, + 'car_yellow': {'name': 'car_yellow', 'type': 'procedural', 'data': _recolored_bikes[9]}, + } + + +_bg_path = os.path.join(get_base_sprite_dir(), "freeway", "background.npy") +_bg_array = _jr.loadFrame(_bg_path) + +_lane_color_rule = [ + {'source': (214, 214, 214), 'target': (0, 0, 0)}, # Lane separation black + {'source': (252, 252, 84), 'target': (255, 0, 0)} # Double lane separation red +] +_recolored_bg = _jr.perform_recoloring(_bg_array, _lane_color_rule) + +class NewLaneColorsMod(JaxAtariInternalModPlugin): + """Makes the lane separation black and the double lane separation red.""" + asset_overrides = { + 'background': { + 'name': 'background', + 'type': 'background', + 'data': _recolored_bg + } + } + +_score_paths = [os.path.join(get_base_sprite_dir(), "freeway", f"score_{i}.npy") for i in range(10)] +_score_array = _jr._load_and_pad_digits_from_paths(_score_paths) +_green_score_rule = [{'source': (228, 111, 111), 'target': (0, 255, 0)}] +_recolored_score = _jr.perform_recoloring(_score_array, _green_score_rule) + +class GreenScoreMod(JaxAtariInternalModPlugin): + """Makes the score digits green.""" + asset_overrides = { + 'score_digits': { + 'name': 'score_digits', + 'type': 'digits', + 'data': _recolored_score + } } diff --git a/src/jaxatari/games/mods/freeway_mods.py b/src/jaxatari/games/mods/freeway_mods.py index bba28562f..0fc50577e 100644 --- a/src/jaxatari/games/mods/freeway_mods.py +++ b/src/jaxatari/games/mods/freeway_mods.py @@ -1,6 +1,6 @@ import os from jaxatari.modification import JaxAtariModController -from jaxatari.games.mods.freeway.freeway_mod_plugins import StopAllCarsMod, StaticCarsMod, SlowCarsMod, BlackCarsMod, CenterCarsOnResetMod, InvertSpeed, HallOfFameMod, BikesMod +from jaxatari.games.mods.freeway.freeway_mod_plugins import StopAllCarsMod, StaticCarsMod, SlowCarsMod, BlackCarsMod, CenterCarsOnResetMod, InvertSpeed, HallOfFameMod, BikesMod, FrogMod, NewLaneColorsMod, GreenScoreMod class FreewayEnvMod(JaxAtariModController): """ @@ -18,6 +18,10 @@ class FreewayEnvMod(JaxAtariModController): "hall_of_fame": ["_hall_of_fame_start", "static_cars"], "_hall_of_fame_start": HallOfFameMod, "bikes": BikesMod, + "frog": FrogMod, + "new_lane_colors": NewLaneColorsMod, + "green_score": GreenScoreMod, + "change_sprites": ["frog", "bikes", "new_lane_colors", "green_score"], } _mod_sprite_dir = os.path.join(os.path.dirname(__file__), "freeway", "sprites") diff --git a/src/jaxatari/games/mods/frostbite/frostbite_mod_plugins.py b/src/jaxatari/games/mods/frostbite/frostbite_mod_plugins.py index b036b14f8..0269adc03 100644 --- a/src/jaxatari/games/mods/frostbite/frostbite_mod_plugins.py +++ b/src/jaxatari/games/mods/frostbite/frostbite_mod_plugins.py @@ -36,18 +36,18 @@ def _spawn_obstacles_vec(self, state: FrostbiteState, spawn_mask: jnp.ndarray) - fish_alive_mask=new_fish_mask ) -class RedIceMod(JaxAtariInternalModPlugin): +class LightBlueIceMod(JaxAtariInternalModPlugin): """ - Changes the color of the ice blocks to red. + Changes the color of the ice blocks to light blue. """ # Frostbite defines colors as Atari hex codes. # We still override the logic colors so they match expected behavior, # but we ALSO provide the RGB tuples to actually tint the sprites! constants_overrides = { - "COLOR_ICE_WHITE": 0x44, # Red - "COLOR_ICE_BLUE": 0x42, # Darker Red - "RGB_ICE_WHITE": (255, 50, 50), - "RGB_ICE_BLUE": (200, 0, 0), + "COLOR_ICE_WHITE": 0x9C, # Light Blue + "COLOR_ICE_BLUE": 0x96, # Darker Blue + "RGB_ICE_WHITE": (173, 216, 230), + "RGB_ICE_BLUE": (70, 130, 180), } class _StaticIceMod(JaxAtariInternalModPlugin): @@ -126,3 +126,65 @@ def after_reset(self, obs, state: FrostbiteState): new_state = self._apply_alignment(state) new_obs = self._env._get_observation(new_state) return new_obs, new_state + + +class RecoloredObstaclesMod(JaxAtariInternalModPlugin): + """ + Changes the colors of the floating obstacles. + Fishes: Red, Geese: Yellow, Crabs: Dark Grey, Clams: Light Grey. + """ + constants_overrides = { + "RGB_FISH": (255, 0, 0), # Red + "RGB_GEESE": (255, 255, 0), # Yellow + "RGB_CRAB": (64, 64, 64), # Dark Grey + "RGB_CLAM": (192, 192, 192), # Light Grey + } + + +class TigerMod(JaxAtariInternalModPlugin): + """ + Switches the bear sprites to tiger sprites. + """ + constants_overrides = { + "BEAR_SPRITE_0": "tiger_00.npy", + "BEAR_SPRITE_1": "tiger_01.npy", + } + + +class WhiteIglooMod(JaxAtariInternalModPlugin): + """ + Makes the igloo white. + """ + constants_overrides = { + "RGB_IGLOO": (255, 255, 255), + } + + +class LeftIglooMod(JaxAtariInternalModPlugin): + """ + Places the igloo on the left side of the screen. + """ + constants_overrides = { + "IGLOO_X_OFFSET": -100, + "TARGET_IGLOO_X": 23, + } + + +class EarlyBearMod(JaxAtariInternalModPlugin): + """ + Spawns the bear from the start (Level 1). + """ + constants_overrides = { + "POLAR_GRIZZLY_LEVEL": 0, + } + + +class DarkNightMod(JaxAtariInternalModPlugin): + """ + Spawns constant night and makes the sky darker. + """ + constants_overrides = { + "CONSTANT_NIGHT": True, + "RGB_NIGHT": (20, 20, 60), # Dark blue sky + "DRAW_SHORE_LINE": True, + } diff --git a/src/jaxatari/games/mods/frostbite_mods.py b/src/jaxatari/games/mods/frostbite_mods.py index f6afda4ab..fffe84910 100644 --- a/src/jaxatari/games/mods/frostbite_mods.py +++ b/src/jaxatari/games/mods/frostbite_mods.py @@ -1,18 +1,26 @@ import os from jaxatari.modification import JaxAtariModController from jaxatari.games.mods.frostbite.frostbite_mod_plugins import ( - NoEnemiesMod, RedIceMod, _StaticIceMod, _MisalignedIceMod, _AlignedIceMod + NoEnemiesMod, LightBlueIceMod, _StaticIceMod, _MisalignedIceMod, _AlignedIceMod, RecoloredObstaclesMod, TigerMod, + WhiteIglooMod, LeftIglooMod, EarlyBearMod, DarkNightMod ) # --- The Registry --- FROSTBITE_MOD_REGISTRY = { "no_enemies": NoEnemiesMod, - "red_ice": RedIceMod, + "lightblue_ice": LightBlueIceMod, + "recolored_obstacles": RecoloredObstaclesMod, + "tiger": TigerMod, + "white_igloo": WhiteIglooMod, + "left_igloo": LeftIglooMod, + "early_bear": EarlyBearMod, + "dark_night": DarkNightMod, "_static_ice": _StaticIceMod, "_misaligned_ice": _MisalignedIceMod, "_aligned_ice": _AlignedIceMod, "static_aligned_ice": ["_static_ice", "_aligned_ice"], "static_misaligned_ice": ["_static_ice", "_misaligned_ice"], + "change_sprites": ["tiger", "white_igloo", "recolored_obstacles", "lightblue_ice"] } class FrostbiteEnvMod(JaxAtariModController): From 8ce4b6a54d71957b530f77a494335e8b0ac14086 Mon Sep 17 00:00:00 2001 From: Quentin Delfosse Date: Sun, 3 May 2026 15:28:56 +0200 Subject: [PATCH 08/28] Tennis new mods --- src/jaxatari/games/jax_tennis.py | 72 ++++++++++++++----- .../games/mods/tennis/tennis_mod_plugins.py | 57 +++++++++++++++ src/jaxatari/games/mods/tennis_mods.py | 12 ++++ 3 files changed, 125 insertions(+), 16 deletions(-) diff --git a/src/jaxatari/games/jax_tennis.py b/src/jaxatari/games/jax_tennis.py index e332dc601..da1357798 100644 --- a/src/jaxatari/games/jax_tennis.py +++ b/src/jaxatari/games/jax_tennis.py @@ -1,6 +1,6 @@ import os from functools import partial -from typing import Tuple, Dict, Any +from typing import Tuple, Dict, Any, Optional import jax.lax import jax.numpy as jnp import chex @@ -93,6 +93,12 @@ class TennisConstants(AutoDerivedConstants): ENEMY_CONST: chex.Array = struct.field(pytree_node=False, default_factory=lambda: jnp.array(1)) GAME_MIDDLE_HORIZONTAL: chex.Array = struct.field(pytree_node=False, default=None) + # Visual overrides + RGB_COURT: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_BLUE: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_RED: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_LINES: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + # Asset config baked into constants (immutable default) for asset overrides ASSET_CONFIG: tuple = struct.field(pytree_node=False, default=_get_default_asset_config()) @@ -1450,10 +1456,38 @@ def __init__(self, consts: TennisConstants = None, config: render_utils.Renderer # Get file-based assets from consts final_asset_config = list(self.consts.ASSET_CONFIG) + court_color = self.consts.RGB_COURT if self.consts.RGB_COURT is not None else (82, 126, 45) + blue_color = self.consts.RGB_BLUE if self.consts.RGB_BLUE is not None else (117, 128, 240) + red_color = self.consts.RGB_RED if self.consts.RGB_RED is not None else (240, 128, 128) + lines_color = self.consts.RGB_LINES if self.consts.RGB_LINES is not None else (214, 214, 214) + + recolor_rules = [] + if self.consts.RGB_COURT is not None: + recolor_rules.append({'source': (82, 126, 45), 'target': court_color}) + if self.consts.RGB_BLUE is not None: + recolor_rules.append({'source': (117, 128, 240), 'target': blue_color}) + if self.consts.RGB_RED is not None: + recolor_rules.append({'source': (240, 128, 128), 'target': red_color}) + if self.consts.RGB_LINES is not None: + recolor_rules.append({'source': (214, 214, 214), 'target': lines_color}) + + if recolor_rules: + background_rgba = self.jr.perform_recoloring(background_rgba, recolor_rules) + + # Recolor everything in the asset config that is file-based (like players) + for i in range(len(final_asset_config)): + if 'recolorings' not in final_asset_config[i]: + final_asset_config[i] = dict(final_asset_config[i]) + final_asset_config[i]['recolorings'] = {'mods': recolor_rules} + else: + # We won't worry about merging if they already have recolorings, + # but none do in the default asset config. + pass + # Create procedural assets procedural_colors = jnp.array([ - [117, 128, 240, 255], - [240, 128, 128, 255], + [blue_color[0], blue_color[1], blue_color[2], 255], + [red_color[0], red_color[1], red_color[2], 255], ], dtype=jnp.uint8).reshape(-1, 1, 1, 4) # Add procedural and specially-handled background assets @@ -1473,12 +1507,12 @@ def __init__(self, consts: TennisConstants = None, config: render_utils.Renderer self.FLIP_OFFSETS ) = self.jr.load_and_setup_assets(final_asset_config, self.sprite_path) # 5. Get color IDs - self.BLUE_ID = self.COLOR_TO_ID.get((117, 128, 240), 0) - self.RED_ID = self.COLOR_TO_ID.get((240, 128, 128), 0) + self.BLUE_ID = self.COLOR_TO_ID.get(blue_color, 0) + self.RED_ID = self.COLOR_TO_ID.get(red_color, 0) if self.BLUE_ID == 0 or self.RED_ID == 0: - blue_rgb = np.array([117, 128, 240], dtype=np.uint8) - red_rgb = np.array([240, 128, 128], dtype=np.uint8) + blue_rgb = np.array(blue_color, dtype=np.uint8) + red_rgb = np.array(red_color, dtype=np.uint8) palette_np = np.array(self.PALETTE) palette_rgb = palette_np[:, :3] if palette_np.shape[1] >= 3 else palette_np if self.BLUE_ID == 0: @@ -1496,17 +1530,21 @@ def swap_colors(mask): self.BG_TOP_MASK = self.BACKGROUND self.BG_BOTTOM_MASK = swap_colors(self.BG_TOP_MASK) - self.PLAYER_STACK = self.SHAPE_MASKS['player_anim'] + suffix = '_mods' if recolor_rules else '' + def get_mask(name): + return self.SHAPE_MASKS.get(name + suffix, self.SHAPE_MASKS[name]) + + self.PLAYER_STACK = get_mask('player_anim') self.ENEMY_STACK = vmap_swap(self.PLAYER_STACK) - self.PLAYER_RACKET_STACK = self.SHAPE_MASKS['racket_anim'] + self.PLAYER_RACKET_STACK = get_mask('racket_anim') self.ENEMY_RACKET_STACK = vmap_swap(self.PLAYER_RACKET_STACK) - self.PLAYER_DIGITS_STACK = vmap_swap(self.SHAPE_MASKS['digits']) - self.ENEMY_DIGITS_STACK = self.SHAPE_MASKS['digits'] + self.PLAYER_DIGITS_STACK = vmap_swap(get_mask('digits')) + self.ENEMY_DIGITS_STACK = get_mask('digits') - self.PLAYER_UI_A = swap_colors(self.SHAPE_MASKS['ui_a']) - self.ENEMY_UI_A = self.SHAPE_MASKS['ui_a'] + self.PLAYER_UI_A = swap_colors(get_mask('ui_a')) + self.ENEMY_UI_A = get_mask('ui_a') # 7. Store animation lengths self.anim_len = { 'player': self.ENEMY_STACK.shape[0], @@ -1608,10 +1646,11 @@ def render(self, state: TennisState) -> jnp.ndarray: raster = self.jr.create_object_raster(bg_mask) # 2. Render Ball Shadow + ball_shadow_mask = self.SHAPE_MASKS.get('ball_shadow_mods', self.SHAPE_MASKS['ball_shadow']) raster = self.jr.render_at_clipped( raster, state.ball_state.ball_x, state.ball_state.ball_y, - self.SHAPE_MASKS['ball_shadow'], - flip_offset=self.FLIP_OFFSETS['ball_shadow'] + ball_shadow_mask, + flip_offset=self.FLIP_OFFSETS.get('ball_shadow_mods', self.FLIP_OFFSETS['ball_shadow']) ) # 3. Render Player & Enemy @@ -1685,9 +1724,10 @@ def render(self, state: TennisState) -> jnp.ndarray: ) # 4. Render Ball + ball_mask = self.SHAPE_MASKS.get('ball_mods', self.SHAPE_MASKS['ball']) raster = self.jr.render_at_clipped( raster, state.ball_state.ball_x, state.ball_state.ball_y - state.ball_state.ball_z, - self.SHAPE_MASKS['ball'], flip_offset=self.FLIP_OFFSETS['ball'] + ball_mask, flip_offset=self.FLIP_OFFSETS.get('ball_mods', self.FLIP_OFFSETS['ball']) ) # 5. Render Score UI diff --git a/src/jaxatari/games/mods/tennis/tennis_mod_plugins.py b/src/jaxatari/games/mods/tennis/tennis_mod_plugins.py index 6bea8faf7..18778df29 100644 --- a/src/jaxatari/games/mods/tennis/tennis_mod_plugins.py +++ b/src/jaxatari/games/mods/tennis/tennis_mod_plugins.py @@ -317,3 +317,60 @@ def enemy_y_step(): new_enemy_direction, enemy_state_after_y_step.y_movement_direction ) + +class ClayCourtMod(JaxAtariInternalModPlugin): + """ + Changes the court colors to a clay court aesthetic (orange/red). + """ + constants_overrides = { + "RGB_COURT": (180, 80, 40), + } + +class GrassCourtMod(JaxAtariInternalModPlugin): + """ + Changes the court colors to a grass court aesthetic (green). + """ + constants_overrides = { + "RGB_COURT": (40, 140, 60), + } + +class HardCourtMod(JaxAtariInternalModPlugin): + """ + Changes the court colors to a hard court aesthetic (light blue and dark blue). + """ + constants_overrides = { + "RGB_COURT": (40, 80, 140), + } + +class NightMod(JaxAtariInternalModPlugin): + """ + Darkens the court for a night-time aesthetic. + """ + constants_overrides = { + "RGB_COURT": (15, 25, 15), + "RGB_BLUE": (58, 64, 120), # Dimmed original (117, 128, 240) + "RGB_RED": (120, 64, 64), # Dimmed original (240, 128, 128) + "RGB_LINES": (100, 100, 100), + } + +class GrayscaleMod(JaxAtariInternalModPlugin): + """ + Makes the court grayscale. + """ + constants_overrides = { + "RGB_COURT": (60, 60, 60), + "RGB_BLUE": (120, 120, 120), + "RGB_RED": (80, 80, 80), + "RGB_LINES": (200, 200, 200), + } + + +class InvertedColorsMod(JaxAtariInternalModPlugin): + """ + Swaps the player and opponent colors. + """ + constants_overrides = { + "RGB_BLUE": (240, 128, 128), + "RGB_RED": (117, 128, 240), + } + diff --git a/src/jaxatari/games/mods/tennis_mods.py b/src/jaxatari/games/mods/tennis_mods.py index ab031d491..9dc77e0b5 100644 --- a/src/jaxatari/games/mods/tennis_mods.py +++ b/src/jaxatari/games/mods/tennis_mods.py @@ -8,6 +8,12 @@ LazyEnemyMod, HighBounceMod, FastEnemyMod, + ClayCourtMod, + GrassCourtMod, + HardCourtMod, + NightMod, + GrayscaleMod, + InvertedColorsMod, ) class TennisEnvMod(JaxAtariModController): @@ -24,6 +30,12 @@ class TennisEnvMod(JaxAtariModController): "lazy_enemy": LazyEnemyMod, "high_bounce": HighBounceMod, "fast_enemy": FastEnemyMod, + "clay_court": ClayCourtMod, + "grass_court": GrassCourtMod, + "hard_court": HardCourtMod, + "night_mode": NightMod, + "grayscale": GrayscaleMod, + "inverted_colors": InvertedColorsMod, } _mod_sprite_dir = os.path.join(os.path.dirname(__file__), "tennis", "sprites") From ac136d862faf8be461f4afce087785e93a2a3ed4 Mon Sep 17 00:00:00 2001 From: Quentin Delfosse Date: Sun, 3 May 2026 16:57:30 +0200 Subject: [PATCH 09/28] Gravitar new mods --- src/jaxatari/games/jax_gravitar.py | 29 ++++++++++ .../mods/gravitar/gravitar_mod_plugins.py | 55 +++++++++++++++++++ src/jaxatari/games/mods/gravitar_mods.py | 8 +++ 3 files changed, 92 insertions(+) diff --git a/src/jaxatari/games/jax_gravitar.py b/src/jaxatari/games/jax_gravitar.py index c4ca726d4..2d295c7bb 100644 --- a/src/jaxatari/games/jax_gravitar.py +++ b/src/jaxatari/games/jax_gravitar.py @@ -108,6 +108,9 @@ class GravitarConstants(struct.PyTreeNode): FORCE_SPRITES: bool = struct.field(pytree_node=False, default=True) SCALE: int = struct.field(pytree_node=False, default=1) + # Visual overrides + RECOLOR_RULES: tuple = struct.field(pytree_node=False, default=()) + # Object limits MAX_BULLETS: int = struct.field(pytree_node=False, default=16) # reduced from 64 for faster compilation MAX_ENEMIES: int = struct.field(pytree_node=False, default=4) # reduced from 16 for faster compilation @@ -3643,6 +3646,14 @@ def __init__(self, width: int = None, height: int = None, consts: GravitarConsta self.FLIP_OFFSETS, ) = self.jr.load_and_setup_assets(asset_config, sprite_dir) + if self.consts.RECOLOR_RULES: + for k in list(self.SHAPE_MASKS.keys()): + if k.endswith('_mods'): + base_k = k[:-5] + self.SHAPE_MASKS[base_k] = self.SHAPE_MASKS[k] + if k in self.FLIP_OFFSETS: + self.FLIP_OFFSETS[base_k] = self.FLIP_OFFSETS[k] + SM = self.SHAPE_MASKS T = self.jr.TRANSPARENT_ID @@ -3737,6 +3748,18 @@ def _build_asset_config(self, obs_sprites: tuple) -> list: ) asset_config.append({'name': 'enemy_orange_flipped', 'type': 'procedural', 'data': orange_flipped}) + if self.consts.RECOLOR_RULES: + recolor_rules = list(self.consts.RECOLOR_RULES) + for i in range(len(asset_config)): + if asset_config[i].get('type') == 'background': + continue + asset_config[i] = dict(asset_config[i]) + if 'recolorings' not in asset_config[i]: + asset_config[i]['recolorings'] = {'mods': recolor_rules} + else: + asset_config[i]['recolorings'] = dict(asset_config[i]['recolorings']) + asset_config[i]['recolorings']['mods'] = recolor_rules + return asset_config def _build_terrain_rasters(self, obs_sprites: tuple) -> jnp.ndarray: @@ -3776,6 +3799,12 @@ def _build_terrain_rasters(self, obs_sprites: tuple) -> jnp.ndarray: if surf is None: bank.append(empty_page.copy()) continue + + if self.consts.RECOLOR_RULES: + surf = np.array(self.jr.perform_recoloring( + jnp.array(surf, dtype=jnp.uint8), + list(self.consts.RECOLOR_RULES) + )) th, tw = surf.shape[0], surf.shape[1] scale = min(GW / tw, GH / th) diff --git a/src/jaxatari/games/mods/gravitar/gravitar_mod_plugins.py b/src/jaxatari/games/mods/gravitar/gravitar_mod_plugins.py index b428e2f87..7243f639e 100644 --- a/src/jaxatari/games/mods/gravitar/gravitar_mod_plugins.py +++ b/src/jaxatari/games/mods/gravitar/gravitar_mod_plugins.py @@ -109,3 +109,58 @@ class LongRangeTractorMod(JaxAtariInternalModPlugin): constants_overrides = { "TRACTOR_BEAM_RANGE": 50.0, } + +class NeonMod(JaxAtariInternalModPlugin): + """ + Changes colors to bright neon variants. + """ + constants_overrides = { + "RECOLOR_RULES": ( + {"source": (101, 183, 217), "target": (255, 20, 147)}, + {"source": (198, 108, 58), "target": (0, 255, 0)}, + {"source": (72, 160, 72), "target": (255, 255, 0)}, + {"source": (223, 183, 85), "target": (0, 255, 255)}, + ) + } + +class RedAlertMod(JaxAtariInternalModPlugin): + """ + Makes all terrain red/orange for a high-alert aesthetic. + """ + constants_overrides = { + "RECOLOR_RULES": ( + {"source": (223, 183, 85), "target": (255, 50, 50)}, + {"source": (84, 160, 197), "target": (220, 40, 40)}, + {"source": (66, 72, 200), "target": (200, 30, 30)}, + {"source": (213, 130, 74), "target": (255, 0, 0)}, + ) + } + +class GrayscaleMod(JaxAtariInternalModPlugin): + """ + Converts the visual palette to grayscale. + """ + constants_overrides = { + "RECOLOR_RULES": ( + {"source": (223, 183, 85), "target": (150, 150, 150)}, + {"source": (84, 160, 197), "target": (120, 120, 120)}, + {"source": (66, 72, 200), "target": (80, 80, 80)}, + {"source": (228, 111, 111), "target": (140, 140, 140)}, + {"source": (213, 130, 74), "target": (160, 160, 160)}, + {"source": (101, 183, 217), "target": (220, 220, 220)}, + {"source": (198, 108, 58), "target": (110, 110, 110)}, + {"source": (72, 160, 72), "target": (90, 90, 90)}, + ) + } + +class InvertedColorsMod(JaxAtariInternalModPlugin): + """ + Inverts the primary colors. + """ + constants_overrides = { + "RECOLOR_RULES": ( + {"source": (101, 183, 217), "target": (154, 72, 38)}, + {"source": (223, 183, 85), "target": (32, 72, 170)}, + {"source": (84, 160, 197), "target": (171, 95, 58)}, + ) + } diff --git a/src/jaxatari/games/mods/gravitar_mods.py b/src/jaxatari/games/mods/gravitar_mods.py index 5e3962ae2..6f109d0ef 100644 --- a/src/jaxatari/games/mods/gravitar_mods.py +++ b/src/jaxatari/games/mods/gravitar_mods.py @@ -11,6 +11,10 @@ InfiniteFuelMod, SlowEnemiesMod, LongRangeTractorMod, + NeonMod, + RedAlertMod, + GrayscaleMod, + InvertedColorsMod, ) @@ -29,6 +33,10 @@ class GravitarEnvMod(JaxAtariModController): "infinite_fuel": InfiniteFuelMod, "slow_enemies": SlowEnemiesMod, "long_range_tractor": LongRangeTractorMod, + "neon_mode": NeonMod, + "red_alert": RedAlertMod, + "grayscale": GrayscaleMod, + "inverted_colors": InvertedColorsMod, } def __init__( From 0bbca7cdd692d93667f5d78d6929dfc48922fbe7 Mon Sep 17 00:00:00 2001 From: Quentin Delfosse Date: Sun, 3 May 2026 17:30:47 +0200 Subject: [PATCH 10/28] Qbert new mods --- src/jaxatari/games/jax_qbert.py | 75 +++++++++++++++--- .../games/mods/qbert/qbert_mod_plugins.py | 78 +++++++++++++++++++ src/jaxatari/games/mods/qbert_mods.py | 16 +++- 3 files changed, 158 insertions(+), 11 deletions(-) diff --git a/src/jaxatari/games/jax_qbert.py b/src/jaxatari/games/jax_qbert.py index b77167e06..8da4ce45b 100644 --- a/src/jaxatari/games/jax_qbert.py +++ b/src/jaxatari/games/jax_qbert.py @@ -6,7 +6,7 @@ # Simulates the Atari Q*bert game # import os -from typing import Tuple +from typing import Tuple, Optional from functools import partial import numpy as np @@ -95,6 +95,16 @@ class QbertConstants(struct.PyTreeNode): SAM_REWARD: int = struct.field(pytree_node=False, default=300) COILY_REWARD: int = struct.field(pytree_node=False, default=500) ROUND_COMPLETE_REWARD: int = struct.field(pytree_node=False, default=3100) + + # Visual overrides + RGB_BACKGROUND: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_QBERT: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_COILY: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_SAM: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_RED_BALL: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_CUBE_START: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_CUBE_INTER: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_CUBE_DEST: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) class QbertState(struct.PyTreeNode): @@ -1164,6 +1174,43 @@ def __init__(self, consts: QbertConstants = None, config: render_utils.RendererC final_asset_config = list(self.consts.ASSET_CONFIG) sprite_path = os.path.join(render_utils.get_base_sprite_dir(), "qbert") + has_recolorings = False + for i in range(len(final_asset_config)): + asset_name = final_asset_config[i]['name'] + asset_rules = [] + + if asset_name in ('background',): + if self.consts.RGB_BACKGROUND is not None: + asset_rules.append({'source': (0, 0, 0), 'target': self.consts.RGB_BACKGROUND}) + elif asset_name in ('qbert_sprites', 'qbert_live', 'dead'): + if self.consts.RGB_QBERT is not None: + asset_rules.append({'source': (181, 83, 40), 'target': self.consts.RGB_QBERT}) + elif asset_name in ('coily', 'purple_ball', 'snake'): + if self.consts.RGB_COILY is not None: + asset_rules.append({'source': (146, 70, 192), 'target': self.consts.RGB_COILY}) + elif asset_name in ('sam', 'green_ball'): + if self.consts.RGB_SAM is not None: + asset_rules.append({'source': (50, 132, 50), 'target': self.consts.RGB_SAM}) + elif asset_name in ('red_ball',): + if self.consts.RGB_RED_BALL is not None: + asset_rules.append({'source': (173, 5, 64), 'target': self.consts.RGB_RED_BALL}) + elif asset_name in ('color_start',): + if self.consts.RGB_CUBE_START is not None: + asset_rules.append({'source': (45, 87, 176), 'target': self.consts.RGB_CUBE_START}) + elif asset_name in ('color_intermediate',): + if self.consts.RGB_CUBE_INTER is not None: + asset_rules.append({'source': (110, 156, 66), 'target': self.consts.RGB_CUBE_INTER}) + elif asset_name in ('color_destination',): + if self.consts.RGB_CUBE_DEST is not None: + asset_rules.append({'source': (210, 210, 64), 'target': self.consts.RGB_CUBE_DEST}) + elif asset_name in ('win_animation',): + pass # Usually handled differently or uses multiple colors + + if asset_rules: + final_asset_config[i] = dict(final_asset_config[i]) + final_asset_config[i]['recolorings'] = {'mods': asset_rules} + has_recolorings = True + ( self.PALETTE, self.SHAPE_MASKS, @@ -1171,6 +1218,9 @@ def __init__(self, consts: QbertConstants = None, config: render_utils.RendererC self.COLOR_TO_ID, self.FLIP_OFFSETS, ) = self.jr.load_and_setup_assets(final_asset_config, sprite_path) + + # Suffix handling for masks when mods are active + self._mask_suffix = '_mods' if has_recolorings else '' freeze_rgba = self.jr.loadFrame(os.path.join(sprite_path, "freeze.npy")) self.FREEZE_BG = self.jr._create_background_raster(freeze_rgba, self.COLOR_TO_ID) @@ -1184,7 +1234,7 @@ def __init__(self, consts: QbertConstants = None, config: render_utils.RendererC # Pre-calculate Backgrounds with Shadows def render_shadows(round_idx): r = self.BACKGROUND - M = self.SHAPE_MASKS + M = {k: self.SHAPE_MASKS.get(k + self._mask_suffix, self.SHAPE_MASKS[k]) for k in self.SHAPE_MASKS if not k.endswith('_mods')} jr = self.jr r = jr.render_at(r, 68, 40, M['cube_shadow_right'][round_idx]) r = jr.render_at(r, 56, 69, M['cube_shadow_right'][round_idx]) @@ -1224,8 +1274,11 @@ def render_shadows(round_idx): self.PYRAMID_LOCAL_X = jnp.zeros((target_h, target_w), dtype=jnp.int32) # Use a dummy raster to find where each cube is + def get_mask(key): + return self.SHAPE_MASKS.get(key + self._mask_suffix, self.SHAPE_MASKS[key]) + # We need to use the masks as they are in SHAPE_MASKS (might be downscaled) - cube_mask = self.SHAPE_MASKS['color_start'] + cube_mask = get_mask('color_start') ch, cw = cube_mask.shape for idx in range(21): @@ -1261,7 +1314,7 @@ def render_shadows(round_idx): self.LIVES_MAP = jnp.full((target_h, target_w), -1, dtype=jnp.int32) self.LIVES_LOCAL_Y = jnp.zeros((target_h, target_w), dtype=jnp.int32) self.LIVES_LOCAL_X = jnp.zeros((target_h, target_w), dtype=jnp.int32) - live_mask = self.SHAPE_MASKS['qbert_live'] + live_mask = get_mask('qbert_live') lh, lw = live_mask.shape for idx in range(9): x, y = self.LIVE_POSITIONS[idx][0], self.LIVE_POSITIONS[idx][1] @@ -1284,7 +1337,7 @@ def render_shadows(round_idx): self.SCORE_MAP = jnp.full((target_h, target_w), -1, dtype=jnp.int32) self.SCORE_LOCAL_Y = jnp.zeros((target_h, target_w), dtype=jnp.int32) self.SCORE_LOCAL_X = jnp.zeros((target_h, target_w), dtype=jnp.int32) - score_masks = self.SHAPE_MASKS['score_digits'] + score_masks = get_mask('score_digits') sh, sw = score_masks[0].shape for idx in range(5): x, y = 34 + idx * 8, 6 @@ -1308,11 +1361,11 @@ def render_shadows(round_idx): # Pre-stack all cube sprites for fast vectorized lookup self.ALL_CUBE_SPRITES = jnp.stack([ - self.SHAPE_MASKS['color_start'], - self.SHAPE_MASKS['color_intermediate'], - self.SHAPE_MASKS['color_destination'] + get_mask('color_start'), + get_mask('color_intermediate'), + get_mask('color_destination') ]) # (3, 5, 20) - self.WIN_ANIMATION_SPRITES = self.SHAPE_MASKS['win_animation'] # (32, 5, 20) + self.WIN_ANIMATION_SPRITES = get_mask('win_animation') # (32, 5, 20) self.QBERT_POSITIONS = jnp.array([[74, 18], [62, 47], [86, 47], [50, 76], [74, 76], [98, 76], [38, 105], [62, 105], [86, 105], [110, 105], [26, 134], [50, 134], [74, 134], [98, 134], [122, 134], [14, 163], [38, 163], [62, 163], [86, 163], [110, 163], [134, 163]]).astype(jnp.int32) self.QBERT_MOVE_RIGHT_DOWN = jnp.array([[0, 0], [1, -1], [2, -2], [3, -3], [4, -4], [5, -5], [6, -6], [7, -5], [8, -4], [9, -3], [10, -2], [11, -1], [12, -0], [12, 1], [12, 3], [12, 5], [12, 8], [12, 10], [12, 12], [12, 14], [12, 17], [12, 19], [12, 21], [12, 23], [12, 25], [12, 27], [12, 29], [12, 29], [12, 29], [12, 29]]).astype(jnp.int32) @@ -1357,7 +1410,9 @@ def _draw_colors(self, raster: jnp.ndarray, state: QbertState, pyra: jnp.ndarray @partial(jax.jit, static_argnums=(0,)) def render(self, state: QbertState) -> jnp.ndarray: jr = self.jr - M = self.SHAPE_MASKS + def get_mask(key): + return self.SHAPE_MASKS.get(key + self._mask_suffix, self.SHAPE_MASKS[key]) + M = {k: get_mask(k) for k in self.SHAPE_MASKS if not k.endswith('_mods')} round_idx = jnp.where(state.next_round_animation_counter != 0, state.round_number - 2, state.round_number - 1) round_idx = jnp.clip(round_idx, 0, 4) diff --git a/src/jaxatari/games/mods/qbert/qbert_mod_plugins.py b/src/jaxatari/games/mods/qbert/qbert_mod_plugins.py index 3543ec638..b97fe6e03 100644 --- a/src/jaxatari/games/mods/qbert/qbert_mod_plugins.py +++ b/src/jaxatari/games/mods/qbert/qbert_mod_plugins.py @@ -593,3 +593,81 @@ class CollectingBonusOnlyMod(JaxAtariInternalModPlugin): "ROUND_COMPLETE_REWARD": 0 } + +class IcePyramidMod(JaxAtariInternalModPlugin): + """ + Changes the pyramid to an icy aesthetic. + """ + constants_overrides = { + "RGB_CUBE_START": (173, 216, 230), + "RGB_CUBE_INTER": (135, 206, 235), + "RGB_CUBE_DEST": (70, 130, 180), + } + +class DarkPyramidMod(JaxAtariInternalModPlugin): + """ + Changes the pyramid to a dark/obsidian aesthetic. + """ + constants_overrides = { + "RGB_CUBE_START": (40, 40, 40), + "RGB_CUBE_INTER": (80, 80, 80), + "RGB_CUBE_DEST": (20, 20, 20), + "RGB_BACKGROUND": (15, 15, 15), + } + +class NightMod(JaxAtariInternalModPlugin): + """ + Dims the background and the characters for a night-time aesthetic. + """ + constants_overrides = { + "RGB_BACKGROUND": (10, 10, 20), + "RGB_QBERT": (90, 41, 20), + "RGB_COILY": (73, 35, 96), + "RGB_SAM": (25, 66, 25), + "RGB_CUBE_START": (22, 43, 88), + "RGB_CUBE_INTER": (55, 78, 33), + "RGB_CUBE_DEST": (105, 105, 32), + } + +class GrayscaleMod(JaxAtariInternalModPlugin): + """ + Makes the entire game grayscale. + """ + constants_overrides = { + "RGB_QBERT": (120, 120, 120), + "RGB_COILY": (100, 100, 100), + "RGB_SAM": (80, 80, 80), + "RGB_CUBE_START": (60, 60, 60), + "RGB_CUBE_INTER": (120, 120, 120), + "RGB_CUBE_DEST": (180, 180, 180), + "RGB_BACKGROUND": (20, 20, 20), + } + +class InvertedColorsMod(JaxAtariInternalModPlugin): + """ + Swaps Q*bert's and Coily's colors. + """ + constants_overrides = { + "RGB_QBERT": (146, 70, 192), + "RGB_COILY": (181, 83, 40), + } + + +class SwapCollectiblesEnemiesMod(JaxAtariInternalModPlugin): + """ + Swaps the colors of the collectibles (Green Ball) and Enemies (Red Ball/Coily). + """ + constants_overrides = { + "RGB_SAM": (173, 5, 64), + "RGB_RED_BALL": (50, 132, 50), + "RGB_COILY": (50, 132, 50), + } + +class RedCoilyMod(JaxAtariInternalModPlugin): + """ + Makes Coily red. + """ + constants_overrides = { + "RGB_COILY": (173, 5, 64), + } + diff --git a/src/jaxatari/games/mods/qbert_mods.py b/src/jaxatari/games/mods/qbert_mods.py index ff97cb70d..f5e4d2dfd 100644 --- a/src/jaxatari/games/mods/qbert_mods.py +++ b/src/jaxatari/games/mods/qbert_mods.py @@ -8,7 +8,14 @@ NoEnemiesMod, DiagonalControlMod, SwapColorsMod, - CollectingBonusOnlyMod + CollectingBonusOnlyMod, + IcePyramidMod, + DarkPyramidMod, + NightMod, + GrayscaleMod, + InvertedColorsMod, + SwapCollectiblesEnemiesMod, + RedCoilyMod ) class QbertEnvMod(JaxAtariModController): @@ -25,6 +32,13 @@ class QbertEnvMod(JaxAtariModController): "diagonal_control": DiagonalControlMod, "swap_colors": SwapColorsMod, "collecting_bonus_only": CollectingBonusOnlyMod, + "ice_pyramid": IcePyramidMod, + "dark_pyramid": DarkPyramidMod, + "night_mode": NightMod, + "grayscale": GrayscaleMod, + "inverted_colors": InvertedColorsMod, + "swap_collectibles_enemies": SwapCollectiblesEnemiesMod, + "red_coily": RedCoilyMod, } def __init__(self, From a40a8ef3d9090490c46e42057be195c0ca76ceb4 Mon Sep 17 00:00:00 2001 From: Quentin Delfosse Date: Sun, 3 May 2026 18:22:08 +0200 Subject: [PATCH 11/28] Venture new mods --- src/jaxatari/games/jax_venture.py | 172 ++++++++++++++---- .../games/mods/venture/venture_mod_plugins.py | 132 +++++++++----- src/jaxatari/games/mods/venture_mods.py | 48 +++-- 3 files changed, 248 insertions(+), 104 deletions(-) diff --git a/src/jaxatari/games/jax_venture.py b/src/jaxatari/games/jax_venture.py index 4acc7da7f..eb4411d25 100644 --- a/src/jaxatari/games/jax_venture.py +++ b/src/jaxatari/games/jax_venture.py @@ -1,6 +1,6 @@ import os from functools import partial -from typing import Dict, Any, Tuple +from typing import Dict, Any, Tuple, Optional from functools import lru_cache import jax @@ -433,6 +433,24 @@ class VentureConstants(AutoDerivedConstants): default_factory=lambda: jnp.zeros((2, 1), dtype=jnp.int32) ) + # Base Colors + RGB_BACKGROUND: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_PLAYER_DETAILED: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_W1_WALLS: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_W2_WALLS: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + + # Monsters + RGB_MONSTER_W1_MAP: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_MONSTER_W1_R2: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_MONSTER_W1_R3: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_MONSTER_W1_R4: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + + RGB_MONSTER_W2_MAP: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_MONSTER_W2_R1: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_MONSTER_W2_R2: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_MONSTER_W2_R3: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_MONSTER_W2_R4: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + class LaserState(struct.PyTreeNode): """Holds the state of the moving laser walls.""" @@ -525,6 +543,20 @@ def __init__(self, consts: VentureConstants = None): JAX_TRANSITIONS=static_data["jax_transitions"], MAIN_MAP_PORTAL_MASKS=static_data["main_map_portal_masks"], MAIN_MAP_PORTAL_TO_LEVELS=static_data["main_map_portal_to_levels"], + + RGB_BACKGROUND=base_consts.RGB_BACKGROUND, + RGB_PLAYER_DETAILED=base_consts.RGB_PLAYER_DETAILED, + RGB_W1_WALLS=base_consts.RGB_W1_WALLS, + RGB_W2_WALLS=base_consts.RGB_W2_WALLS, + RGB_MONSTER_W1_MAP=base_consts.RGB_MONSTER_W1_MAP, + RGB_MONSTER_W1_R2=base_consts.RGB_MONSTER_W1_R2, + RGB_MONSTER_W1_R3=base_consts.RGB_MONSTER_W1_R3, + RGB_MONSTER_W1_R4=base_consts.RGB_MONSTER_W1_R4, + RGB_MONSTER_W2_MAP=base_consts.RGB_MONSTER_W2_MAP, + RGB_MONSTER_W2_R1=base_consts.RGB_MONSTER_W2_R1, + RGB_MONSTER_W2_R2=base_consts.RGB_MONSTER_W2_R2, + RGB_MONSTER_W2_R3=base_consts.RGB_MONSTER_W2_R3, + RGB_MONSTER_W2_R4=base_consts.RGB_MONSTER_W2_R4, ) super().__init__(initialized_consts) self.renderer = VentureRenderer(self.consts) @@ -1797,7 +1829,67 @@ def __init__(self, consts: VentureConstants = None, config: render_utils.Rendere # Add procedural sprites (projectiles, lasers) procedural_assets = self._create_procedural_assets(sprite_path) asset_config.extend(procedural_assets) - + + # Inject Recoloring Rules based on Constants + has_recolorings = False + for i in range(len(asset_config)): + asset_name = asset_config[i]['name'] + asset_rules = [] + + if asset_name == 'background': + if self.consts.RGB_BACKGROUND is not None: + asset_rules.append({'source': (0, 0, 0), 'target': self.consts.RGB_BACKGROUND}) + elif asset_name in ('map_w1', 'room1_w1', 'room2_w1', 'room3_w1', 'room4_w1', 'player_dot_w1', 'health_w1'): + if self.consts.RGB_W1_WALLS is not None: + asset_rules.append({'source': (168, 48, 143), 'target': self.consts.RGB_W1_WALLS}) + elif asset_name in ('map_w2', 'room1_w2', 'room2_w2', 'room3_w2', 'room4_w2', 'player_dot_w2', 'health_w2'): + if self.consts.RGB_W2_WALLS is not None: + asset_rules.append({'source': (45, 87, 176), 'target': self.consts.RGB_W2_WALLS}) + elif asset_name in ('player_detailed', 'projectile_resized'): + if self.consts.RGB_PLAYER_DETAILED is not None: + asset_rules.append({'source': (167, 26, 26), 'target': self.consts.RGB_PLAYER_DETAILED}) + elif asset_name in ('monster_map_w1', 'monster_dead_map_w1', 'chaser'): + if self.consts.RGB_MONSTER_W1_MAP is not None: + asset_rules.append({'source': (82, 126, 45), 'target': self.consts.RGB_MONSTER_W1_MAP}) + elif asset_name in ('monster_r2_w1', 'monster_dead_r2_w1'): + if self.consts.RGB_MONSTER_W1_R2 is not None: + asset_rules.append({'source': (82, 126, 45), 'target': self.consts.RGB_MONSTER_W1_R2}) + elif asset_name in ('monster_r3_w1', 'monster_dead_r3_w1'): + if self.consts.RGB_MONSTER_W1_R3 is not None: + asset_rules.append({'source': (78, 50, 181), 'target': self.consts.RGB_MONSTER_W1_R3}) + elif asset_name in ('monster_r4_w1', 'monster_dead_r4_w1'): + if self.consts.RGB_MONSTER_W1_R4 is not None: + asset_rules.append({'source': (111, 111, 111), 'target': self.consts.RGB_MONSTER_W1_R4}) + elif asset_name in ('monster_map_w2', 'monster_dead_map_w2', 'laser_ho', 'laser_ve', 'laser_ve_stretched', 'laser_ho_stretched'): + if self.consts.RGB_MONSTER_W2_MAP is not None: + asset_rules.append({'source': (181, 83, 40), 'target': self.consts.RGB_MONSTER_W2_MAP}) + elif asset_name in ('monster_r1_w2', 'monster_dead_r1_w2'): + if self.consts.RGB_MONSTER_W2_R1 is not None: + asset_rules.append({'source': (184, 50, 50), 'target': self.consts.RGB_MONSTER_W2_R1}) + elif asset_name in ('monster_r2_w2', 'monster_dead_r2_w2'): + if self.consts.RGB_MONSTER_W2_R2 is not None: + asset_rules.append({'source': (111, 111, 111), 'target': self.consts.RGB_MONSTER_W2_R2}) + elif asset_name in ('monster_r3_w2', 'monster_dead_r3_w2'): + if self.consts.RGB_MONSTER_W2_R3 is not None: + asset_rules.append({'source': (134, 134, 29), 'target': self.consts.RGB_MONSTER_W2_R3}) + elif asset_name in ('monster_r4_w2', 'monster_dead_r4_w2'): + if self.consts.RGB_MONSTER_W2_R4 is not None: + asset_rules.append({'source': (181, 83, 40), 'target': self.consts.RGB_MONSTER_W2_R4}) + + # Additional catches for text and rewards + elif asset_name in ('digits',): + if self.consts.RGB_PLAYER_DETAILED is not None: # Use player color for text as it's typically prominent + asset_rules.append({'source': (170, 170, 170), 'target': self.consts.RGB_PLAYER_DETAILED}) + elif asset_name.startswith('reward'): + if self.consts.RGB_PLAYER_DETAILED is not None: # Turn rewards to player color + # Global replace for simplicity on rewards + asset_rules.append({'target': self.consts.RGB_PLAYER_DETAILED}) + + if asset_rules: + asset_config[i] = dict(asset_config[i]) + asset_config[i]['recolorings'] = {'mods': asset_rules} + has_recolorings = True + ( self.PALETTE, self.SHAPE_MASKS, @@ -1805,6 +1897,11 @@ def __init__(self, consts: VentureConstants = None, config: render_utils.Rendere self.COLOR_TO_ID, self.FLIP_OFFSETS ) = self.jr.load_and_setup_assets(asset_config, sprite_path) + + self._mask_suffix = '_mods' if has_recolorings else '' + + def get_mask(key): + return self.SHAPE_MASKS.get(key + self._mask_suffix, self.SHAPE_MASKS[key]) # --- Pre-stack masks for efficient indexing (Avoids Switches/Conds in render) --- @@ -1817,8 +1914,8 @@ def stack_and_pad(masks): # 1. Wall Masks -> Pre-baked Background Rasters # We stamp the wall masks onto the base background raster once during __init__ all_wall_masks = jnp.stack([ - stack_and_pad([self.SHAPE_MASKS['map_w1'], self.SHAPE_MASKS['room1_w1'], self.SHAPE_MASKS['room2_w1'], self.SHAPE_MASKS['room3_w1'], self.SHAPE_MASKS['room4_w1']]), - stack_and_pad([self.SHAPE_MASKS['map_w2'], self.SHAPE_MASKS['room1_w2'], self.SHAPE_MASKS['room2_w2'], self.SHAPE_MASKS['room3_w2'], self.SHAPE_MASKS['room4_w2']]) + stack_and_pad([get_mask('map_w1'), get_mask('room1_w1'), get_mask('room2_w1'), get_mask('room3_w1'), get_mask('room4_w1')]), + stack_and_pad([get_mask('map_w2'), get_mask('room1_w2'), get_mask('room2_w2'), get_mask('room3_w2'), get_mask('room4_w2')]) ]) base_raster = self.jr.create_object_raster(self.BACKGROUND) self.all_background_rasters = jax.vmap(jax.vmap(lambda m: self.jr.render_at(base_raster, 0, 0, m)))(all_wall_masks) @@ -1827,48 +1924,48 @@ def stack_and_pad(masks): # Note: Room 1 in W1 uses map monster. self.all_monster_masks = jnp.stack([ stack_and_pad([ - self.SHAPE_MASKS['monster_map_w1'], - self.SHAPE_MASKS['monster_map_w1'], - self.SHAPE_MASKS['monster_r2_w1'], - self.SHAPE_MASKS['monster_r3_w1'], - self.SHAPE_MASKS['monster_r4_w1'] + get_mask('monster_map_w1'), + get_mask('monster_map_w1'), + get_mask('monster_r2_w1'), + get_mask('monster_r3_w1'), + get_mask('monster_r4_w1') ]), stack_and_pad([ - self.SHAPE_MASKS['monster_map_w2'], - self.SHAPE_MASKS['monster_r1_w2'], - self.SHAPE_MASKS['monster_r2_w2'], - self.SHAPE_MASKS['monster_r3_w2'], - self.SHAPE_MASKS['monster_r4_w2'] + get_mask('monster_map_w2'), + get_mask('monster_r1_w2'), + get_mask('monster_r2_w2'), + get_mask('monster_r3_w2'), + get_mask('monster_r4_w2') ]) ]) # 3. Dead Monster Masks: (World, Level, H, W) self.all_dead_monster_masks = jnp.stack([ stack_and_pad([ - self.SHAPE_MASKS['monster_dead_map_w1'], - self.SHAPE_MASKS['monster_dead_map_w1'], - self.SHAPE_MASKS['monster_dead_r2_w1'], - self.SHAPE_MASKS['monster_dead_r3_w1'], - self.SHAPE_MASKS['monster_dead_r4_w1'] + get_mask('monster_dead_map_w1'), + get_mask('monster_dead_map_w1'), + get_mask('monster_dead_r2_w1'), + get_mask('monster_dead_r3_w1'), + get_mask('monster_dead_r4_w1') ]), stack_and_pad([ - self.SHAPE_MASKS['monster_dead_map_w2'], - self.SHAPE_MASKS['monster_dead_r1_w2'], - self.SHAPE_MASKS['monster_dead_r2_w2'], - self.SHAPE_MASKS['monster_dead_r3_w2'], - self.SHAPE_MASKS['monster_dead_r4_w2'] + get_mask('monster_dead_map_w2'), + get_mask('monster_dead_r1_w2'), + get_mask('monster_dead_r2_w2'), + get_mask('monster_dead_r3_w2'), + get_mask('monster_dead_r4_w2') ]) ]) # 4. Chest Masks: (World, Room, H, W) -> Rooms 1-4 (indices 0-3) self.all_chest_masks = jnp.stack([ - stack_and_pad([self.SHAPE_MASKS['reward1_w1'], self.SHAPE_MASKS['reward2_w1'], self.SHAPE_MASKS['reward3_w1'], self.SHAPE_MASKS['reward4_w1']]), - stack_and_pad([self.SHAPE_MASKS['reward1_w2'], self.SHAPE_MASKS['reward2_w2'], self.SHAPE_MASKS['reward3_w2'], self.SHAPE_MASKS['reward4_w2']]) + stack_and_pad([get_mask('reward1_w1'), get_mask('reward2_w1'), get_mask('reward3_w1'), get_mask('reward4_w1')]), + stack_and_pad([get_mask('reward1_w2'), get_mask('reward2_w2'), get_mask('reward3_w2'), get_mask('reward4_w2')]) ]) # 5. UI and Player Masks - self.all_life_masks = stack_and_pad([self.SHAPE_MASKS['health_w1'], self.SHAPE_MASKS['health_w2']]) - self.all_player_dot_masks = stack_and_pad([self.SHAPE_MASKS['player_dot_w1'], self.SHAPE_MASKS['player_dot_w2']]) + self.all_life_masks = stack_and_pad([get_mask('health_w1'), get_mask('health_w2')]) + self.all_player_dot_masks = stack_and_pad([get_mask('player_dot_w1'), get_mask('player_dot_w2')]) # Precompute static offsets self.monster_offsets = jnp.array([self.consts.MONSTER_RENDER_WIDTH / 2, self.consts.MONSTER_RENDER_HEIGHT / 2], dtype=jnp.int32) @@ -1906,6 +2003,9 @@ def load_resize(filename, target_shape, name): @partial(jax.jit, static_argnums=(0,)) def render(self, state): """Renders the game state to an RGBA image array.""" + def get_mask(key): + return self.SHAPE_MASKS.get(key + self._mask_suffix, self.SHAPE_MASKS[key]) + world_idx = state.world_level - 1 level_idx = state.current_level is_in_room = level_idx > 0 @@ -1915,7 +2015,7 @@ def render(self, state): # --- 2. Score and Lives --- score_digits = self.jr.int_to_digits(state.score, max_digits=6) - canvas = self.jr.render_label(canvas, 8, 10, score_digits, self.SHAPE_MASKS['digits'], spacing=6, max_digits=6) + canvas = self.jr.render_label(canvas, 8, 10, score_digits, get_mask('digits'), spacing=6, max_digits=6) life_mask = self.all_life_masks[world_idx] canvas = self.jr.render_indicator(canvas, 120, 10, state.lives - 1, life_mask, spacing=10, max_value=3) @@ -1978,7 +2078,7 @@ def draw_single_dead_monster(i, _c): chaser_tl = (jnp.array([state.chaser.x, state.chaser.y]) - self.chaser_offsets).astype(jnp.int32) canvas = jax.lax.cond( state.chaser.active, - lambda c: self.jr.render_at(c, chaser_tl[0], chaser_tl[1], self.SHAPE_MASKS['chaser']), + lambda c: self.jr.render_at(c, chaser_tl[0], chaser_tl[1], get_mask('chaser')), lambda c: c, canvas ) @@ -1988,10 +2088,10 @@ def draw_lasers(c): x_span_start, _, y_span_start, _ = self.consts.LASER_ROOM_SPAN thick_h = self.consts.LASER_THICKNESS / 2 - c = self.jr.render_at(c, (state.lasers.positions[0] - thick_h).astype(jnp.int32), y_span_start.astype(jnp.int32), self.SHAPE_MASKS['laser_ve_stretched']) - c = self.jr.render_at(c, (state.lasers.positions[1] - thick_h).astype(jnp.int32), y_span_start.astype(jnp.int32), self.SHAPE_MASKS['laser_ve_stretched']) - c = self.jr.render_at(c, x_span_start.astype(jnp.int32), (state.lasers.positions[2] - thick_h).astype(jnp.int32), self.SHAPE_MASKS['laser_ho_stretched']) - c = self.jr.render_at(c, x_span_start.astype(jnp.int32), (state.lasers.positions[3] - thick_h).astype(jnp.int32), self.SHAPE_MASKS['laser_ho_stretched']) + c = self.jr.render_at(c, (state.lasers.positions[0] - thick_h).astype(jnp.int32), y_span_start.astype(jnp.int32), get_mask('laser_ve_stretched')) + c = self.jr.render_at(c, (state.lasers.positions[1] - thick_h).astype(jnp.int32), y_span_start.astype(jnp.int32), get_mask('laser_ve_stretched')) + c = self.jr.render_at(c, x_span_start.astype(jnp.int32), (state.lasers.positions[2] - thick_h).astype(jnp.int32), get_mask('laser_ho_stretched')) + c = self.jr.render_at(c, x_span_start.astype(jnp.int32), (state.lasers.positions[3] - thick_h).astype(jnp.int32), get_mask('laser_ho_stretched')) return c canvas = jax.lax.cond((level_idx == 1) & (state.world_level == 1), draw_lasers, lambda c: c, canvas) @@ -2001,7 +2101,7 @@ def draw_player(c): def _room(_c): px = (state.player.x - self.player_detailed_offsets[0]).astype(jnp.int32) py = (state.player.y - self.player_detailed_offsets[1]).astype(jnp.int32) - return self.jr.render_at(_c, px, py, self.SHAPE_MASKS['player_detailed']) + return self.jr.render_at(_c, px, py, get_mask('player_detailed')) def _map(_c): mask = self.all_player_dot_masks[world_idx] px = (state.player.x - self.player_dot_offsets[0]).astype(jnp.int32) @@ -2023,7 +2123,7 @@ def draw_aiming_dot(c): def draw_projectile(c): px = (state.projectile.x - self.consts.PROJECTILE_RADIUS).astype(jnp.int32) py = (state.projectile.y - self.consts.PROJECTILE_RADIUS).astype(jnp.int32) - return self.jr.render_at(c, px, py, self.SHAPE_MASKS['projectile_resized']) + return self.jr.render_at(c, px, py, get_mask('projectile_resized')) def draw_room_extras(c): return jax.lax.cond(state.projectile.active, draw_projectile, draw_aiming_dot, c) diff --git a/src/jaxatari/games/mods/venture/venture_mod_plugins.py b/src/jaxatari/games/mods/venture/venture_mod_plugins.py index ef0f225c5..1f6eb597e 100644 --- a/src/jaxatari/games/mods/venture/venture_mod_plugins.py +++ b/src/jaxatari/games/mods/venture/venture_mod_plugins.py @@ -1,58 +1,104 @@ +from typing import Dict, Any, Tuple from jaxatari.modification import JaxAtariInternalModPlugin -import jax.numpy as jnp - - -class FastWinkyMod(JaxAtariInternalModPlugin): - """Increase Winky's movement speed.""" +class NightMod(JaxAtariInternalModPlugin): + """Dims the entire screen by 50% for a night mode experience.""" + name = "night_mode" + + # 50% of the original values based on user constraints constants_overrides = { - "PLAYER_SPEED": 2.0, + 'RGB_BACKGROUND': (0, 0, 0), + 'RGB_PLAYER_DETAILED': (83, 13, 13), + 'RGB_W1_WALLS': (84, 24, 71), + 'RGB_W2_WALLS': (22, 43, 88), + 'RGB_MONSTER_W1_MAP': (41, 63, 22), + 'RGB_MONSTER_W1_R2': (41, 63, 22), + 'RGB_MONSTER_W1_R3': (39, 25, 90), + 'RGB_MONSTER_W1_R4': (55, 55, 55), + 'RGB_MONSTER_W2_MAP': (90, 41, 20), + 'RGB_MONSTER_W2_R1': (92, 25, 25), + 'RGB_MONSTER_W2_R2': (55, 55, 55), + 'RGB_MONSTER_W2_R3': (67, 67, 14), + 'RGB_MONSTER_W2_R4': (90, 41, 20), } - -class SlowMonstersMod(JaxAtariInternalModPlugin): - """Decrease the movement speed of all monsters, including the hallway chaser.""" - +class GrayscaleMod(JaxAtariInternalModPlugin): + """Turns the entire game into grayscale.""" + name = "grayscale" + + # Using luminosity method (0.3R + 0.59G + 0.11B) constants_overrides = { - "MONSTER_SPEEDS": jnp.array([0.5, 0.75, 1.0, 1.25], dtype=jnp.float32), - "CHASER_SPEED": 0.2, + 'RGB_BACKGROUND': (0, 0, 0), + 'RGB_PLAYER_DETAILED': (65, 65, 65), + 'RGB_W1_WALLS': (94, 94, 94), + 'RGB_W2_WALLS': (84, 84, 84), + 'RGB_MONSTER_W1_MAP': (103, 103, 103), + 'RGB_MONSTER_W1_R2': (103, 103, 103), + 'RGB_MONSTER_W1_R3': (72, 72, 72), + 'RGB_MONSTER_W1_R4': (111, 111, 111), + 'RGB_MONSTER_W2_MAP': (107, 107, 107), + 'RGB_MONSTER_W2_R1': (90, 90, 90), + 'RGB_MONSTER_W2_R2': (111, 111, 111), + 'RGB_MONSTER_W2_R3': (119, 119, 119), + 'RGB_MONSTER_W2_R4': (107, 107, 107), } - -class WealthyVentureMod(JaxAtariInternalModPlugin): - """Significantly increase the points awarded for collecting treasures.""" - +class InvertedColorsMod(JaxAtariInternalModPlugin): + """Inverts all colors in the game.""" + name = "inverted_colors" + constants_overrides = { - "CHEST_SCORE": 1000, + 'RGB_BACKGROUND': (255, 255, 255), + 'RGB_PLAYER_DETAILED': (88, 229, 229), + 'RGB_W1_WALLS': (87, 207, 112), + 'RGB_W2_WALLS': (210, 168, 79), + 'RGB_MONSTER_W1_MAP': (173, 129, 210), + 'RGB_MONSTER_W1_R2': (173, 129, 210), + 'RGB_MONSTER_W1_R3': (177, 205, 74), + 'RGB_MONSTER_W1_R4': (144, 144, 144), + 'RGB_MONSTER_W2_MAP': (74, 172, 215), + 'RGB_MONSTER_W2_R1': (71, 205, 205), + 'RGB_MONSTER_W2_R2': (144, 144, 144), + 'RGB_MONSTER_W2_R3': (121, 121, 226), + 'RGB_MONSTER_W2_R4': (74, 172, 215), } - -class PatientChaserMod(JaxAtariInternalModPlugin): - """Increase the time before the hallway chaser appears in a room.""" - +class MatrixMod(JaxAtariInternalModPlugin): + """A Matrix-themed mod: black background, green walls, green monsters, white player.""" + name = "matrix_theme" + constants_overrides = { - "CHASER_SPAWN_FRAMES": 10000, + 'RGB_BACKGROUND': (0, 0, 0), + 'RGB_PLAYER_DETAILED': (255, 255, 255), + 'RGB_W1_WALLS': (0, 200, 0), + 'RGB_W2_WALLS': (0, 150, 0), + 'RGB_MONSTER_W1_MAP': (50, 255, 50), + 'RGB_MONSTER_W1_R2': (50, 255, 50), + 'RGB_MONSTER_W1_R3': (0, 255, 100), + 'RGB_MONSTER_W1_R4': (0, 180, 0), + 'RGB_MONSTER_W2_MAP': (50, 255, 50), + 'RGB_MONSTER_W2_R1': (0, 255, 100), + 'RGB_MONSTER_W2_R2': (0, 180, 0), + 'RGB_MONSTER_W2_R3': (100, 255, 100), + 'RGB_MONSTER_W2_R4': (50, 255, 50), } - -class FastArrowsMod(JaxAtariInternalModPlugin): - """Increase the speed of Winky's arrows.""" - +class BloodMoonMod(JaxAtariInternalModPlugin): + """A dark red themed mod.""" + name = "blood_moon" + constants_overrides = { - "PROJECTILE_SPEED": 4.0, - } - - -class LongRangeArrowsMod(JaxAtariInternalModPlugin): - """Increase the distance arrows travel before disappearing.""" - - constants_overrides = { - "PROJECTILE_LIFETIME_FRAMES": 60, - } - - -class GodModeMod(JaxAtariInternalModPlugin): - """Winky is immune to collisions with hazards.""" - - def _check_player_hazard_collision(self, player_state, monster_state, chaser_state, laser_state, current_level, world_level): - return jnp.array(False) + 'RGB_BACKGROUND': (40, 0, 0), + 'RGB_PLAYER_DETAILED': (255, 255, 255), + 'RGB_W1_WALLS': (180, 0, 0), + 'RGB_W2_WALLS': (150, 0, 0), + 'RGB_MONSTER_W1_MAP': (255, 100, 100), + 'RGB_MONSTER_W1_R2': (255, 100, 100), + 'RGB_MONSTER_W1_R3': (255, 50, 50), + 'RGB_MONSTER_W1_R4': (200, 50, 50), + 'RGB_MONSTER_W2_MAP': (255, 100, 100), + 'RGB_MONSTER_W2_R1': (255, 50, 50), + 'RGB_MONSTER_W2_R2': (200, 50, 50), + 'RGB_MONSTER_W2_R3': (255, 150, 150), + 'RGB_MONSTER_W2_R4': (255, 100, 100), + } \ No newline at end of file diff --git a/src/jaxatari/games/mods/venture_mods.py b/src/jaxatari/games/mods/venture_mods.py index 668acadfa..f46122361 100644 --- a/src/jaxatari/games/mods/venture_mods.py +++ b/src/jaxatari/games/mods/venture_mods.py @@ -1,37 +1,35 @@ -from jaxatari.modification import JaxAtariModController +from jaxatari.modification import JaxAtariModController, JaxAtariModWrapper +from jaxatari.games.jax_venture import JaxVenture, VentureConstants from jaxatari.games.mods.venture.venture_mod_plugins import ( - FastWinkyMod, - SlowMonstersMod, - WealthyVentureMod, - PatientChaserMod, - FastArrowsMod, - LongRangeArrowsMod, - GodModeMod, + NightMod, + GrayscaleMod, + InvertedColorsMod, + MatrixMod, + BloodMoonMod, ) - class VentureEnvMod(JaxAtariModController): - """Game-specific Mod Controller for Venture.""" - + """ + Controller for Venture mods. + """ + REGISTRY = { - "fast_winky": FastWinkyMod, - "slow_monsters": SlowMonstersMod, - "wealthy_venture": WealthyVentureMod, - "patient_chaser": PatientChaserMod, - "fast_arrows": FastArrowsMod, - "long_range_arrows": LongRangeArrowsMod, - "god_mode": GodModeMod, + 'night_mode': NightMod, + 'grayscale': GrayscaleMod, + 'inverted_colors': InvertedColorsMod, + 'matrix_theme': MatrixMod, + 'blood_moon': BloodMoonMod, } - def __init__( - self, - env, - mods_config: list = [], - allow_conflicts: bool = False, - ): + def __init__(self, + env, + mods_config: list = [], + allow_conflicts: bool = False + ): + super().__init__( env=env, mods_config=mods_config, allow_conflicts=allow_conflicts, - registry=self.REGISTRY, + registry=self.REGISTRY ) From 7bf44d112b4c9dc570f4129391aefca3e73792f7 Mon Sep 17 00:00:00 2001 From: Quentin Delfosse Date: Sun, 3 May 2026 23:27:09 +0200 Subject: [PATCH 12/28] Phoenix new mods --- src/jaxatari/games/jax_phoenix.py | 82 ++++++++++++++----- .../games/mods/phoenix/phoenix_mod_plugins.py | 79 ++++++++++++++++++ src/jaxatari/games/mods/phoenix_mods.py | 10 +++ 3 files changed, 149 insertions(+), 22 deletions(-) diff --git a/src/jaxatari/games/jax_phoenix.py b/src/jaxatari/games/jax_phoenix.py index db16f6abf..802ae6de3 100644 --- a/src/jaxatari/games/jax_phoenix.py +++ b/src/jaxatari/games/jax_phoenix.py @@ -1,7 +1,7 @@ import math import os from functools import partial -from typing import Tuple, NamedTuple +from typing import Tuple, NamedTuple, Optional import jax import jax.numpy as jnp import chex @@ -279,6 +279,13 @@ class PhoenixConstants(AutoDerivedConstants): pytree_node=False, default_factory=lambda: (135, 183, 84) ) + # Visual Mod Overrides + RGB_BACKGROUND: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_FLOOR: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_PHOENIX_MAIN: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_BATS_BLUE: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_BATS_RED: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + # --- Enemy spawn grids (5 formations x 8 slots). X: playfield x (unused slot = -1). Y: row height; 230 = below playfield (inactive). --- ENEMY_POSITIONS_X: jnp.ndarray = struct.field(pytree_node=False, default_factory=lambda: jnp.array([ [66, 90, 53, 104, 53, 104, 66, 90], @@ -2423,7 +2430,35 @@ def make_block(w, h, color): {"name": "boss_block_green", "type": "group", "data": [green_block]}, ] ) + + has_recolorings = False + for i in range(len(patched_config)): + asset_name = patched_config[i]['name'] + asset_rules = [] + if asset_name == 'background' and self.consts.RGB_BACKGROUND is not None: + asset_rules.append({'source': (0, 0, 0), 'target': self.consts.RGB_BACKGROUND}) + elif asset_name == 'floor' and self.consts.RGB_FLOOR is not None: + asset_rules.append({'source': (146, 70, 192), 'target': self.consts.RGB_FLOOR}) + elif asset_name in ('player', 'player_ability', 'player_projectile', 'digits', 'life_indicator'): + if self.consts.PLAYER_COLOR != (213, 130, 74): + asset_rules.append({'source': (213, 130, 74), 'target': self.consts.PLAYER_COLOR}) + elif asset_name == 'phoenix': + if self.consts.RGB_PHOENIX_MAIN is not None: + # Target the main body and wings + asset_rules.append({'source': (125, 48, 173), 'target': self.consts.RGB_PHOENIX_MAIN}) + asset_rules.append({'source': (227, 151, 89), 'target': self.consts.RGB_PHOENIX_MAIN}) + elif 'bat_blue' in asset_name and self.consts.RGB_BATS_BLUE is not None: + asset_rules.append({'target': self.consts.RGB_BATS_BLUE}) + elif 'bat_red' in asset_name and self.consts.RGB_BATS_RED is not None: + asset_rules.append({'target': self.consts.RGB_BATS_RED}) + + if asset_rules: + patched_config[i] = dict(patched_config[i]) + patched_config[i]['recolorings'] = {'mods': asset_rules} + has_recolorings = True + final_asset_config = patched_config + self._mask_suffix = '_mods' if has_recolorings else '' # 4. Load all assets, create palette, and generate ID masks in one call ( @@ -2456,6 +2491,9 @@ def _draw_rect_outline(self, raster, x, y, w, h, color_id): cid = jnp.asarray(color_id, dtype=raster.dtype) return jnp.where(outline, cid, raster) + def get_mask(self, key): + return self.SHAPE_MASKS.get(key + self._mask_suffix, self.SHAPE_MASKS[key]) + @partial(jax.jit, static_argnums=(0,)) def render(self, state): # Start with the background raster @@ -2486,7 +2524,7 @@ def render(self, state): @partial(jax.jit, static_argnums=(0,)) def _render_common(self, state, raster): - raster = self.jr.render_at(raster, 0, self.consts.FLOOR_Y, self.SHAPE_MASKS['floor']) + raster = self.jr.render_at(raster, 0, self.consts.FLOOR_Y, self.get_mask('floor')) player_death_sprite_duration = self.consts.PLAYER_DEATH_DURATION // 3 death_idx = jax.lax.select( @@ -2501,7 +2539,7 @@ def _render_common(self, state, raster): jax.lax.select(state.player_moving & anim_toggle, 4, 0) ) player_frame_index = jax.lax.select(state.player_dying, death_idx, alive_idx) - player_mask = self.SHAPE_MASKS["player"][player_frame_index] + player_mask = self.get_mask("player")[player_frame_index] player_flip_offset = self.FLIP_OFFSETS["player"] def draw_player(r): @@ -2513,9 +2551,9 @@ def draw_player(r): ) # Player projectile: don't render if it's inside the player sprite - projectile_mask = self.SHAPE_MASKS["player_projectile"] + projectile_mask = self.get_mask("player_projectile") proj_h, proj_w = projectile_mask.shape - player_mask_local = self.SHAPE_MASKS["player"][player_frame_index] + player_mask_local = self.get_mask("player")[player_frame_index] ph, pw = player_mask_local.shape overlap_x = (state.projectile_x + proj_w > state.player_x) & (state.projectile_x < state.player_x + pw) @@ -2533,8 +2571,8 @@ def render_player_projectile(r): ) def render_ability(r): - ability_mask = self.SHAPE_MASKS['player_ability'] - player_mask_local = self.SHAPE_MASKS["player"][player_frame_index] + ability_mask = self.get_mask('player_ability') + player_mask_local = self.get_mask("player")[player_frame_index] ah, aw = ability_mask.shape ph, pw = player_mask_local.shape ax = state.player_x + (pw - aw) // 2 @@ -2548,7 +2586,7 @@ def render_enemy_projectile(i, current_raster): x, y = state.enemy_projectile_x[i], state.enemy_projectile_y[i] return jax.lax.cond( y > -1, - lambda r: self.jr.render_at(r, x, y, self.SHAPE_MASKS['enemy_projectile']), + lambda r: self.jr.render_at(r, x, y, self.get_mask('enemy_projectile')), lambda r: r, current_raster ) @@ -2575,7 +2613,7 @@ def _render_phoenix_level(self, state, raster, is_level_two: bool): phoenix_death_phase = (state.phoenix_death_timer <= self.consts.ENEMY_DEATH_DURATION // 2).astype(jnp.int32) anim_toggle = ((state.step_counter // self.consts.ENEMY_ANIMATION_SPEED) % 2) == 0 phoenix_flip_offset = self.FLIP_OFFSETS['phoenix'] - green_enemy_mask = self.SHAPE_MASKS['green_enemy'] + green_enemy_mask = self.get_mask('green_enemy') def render_single_phoenix(i, current_raster): x, y = state.enemies_x[i], state.enemies_y[i] @@ -2585,7 +2623,7 @@ def draw_enemy(r): death_idx = jax.lax.select(phoenix_death_phase[i] == 0, 3, 4) alive_idx = jax.lax.select(is_moving_vert[i], 2, jax.lax.select(anim_toggle, 0, 1)) frame_idx = jax.lax.select(phoenix_death_flags[i], death_idx, alive_idx) - phoenix_mask = self.SHAPE_MASKS['phoenix'][frame_idx] + phoenix_mask = self.get_mask('phoenix')[frame_idx] use_green_enemy = is_level_two & (~phoenix_death_flags[i]) & (~is_moving_vert[i]) return jax.lax.cond( use_green_enemy, @@ -2601,13 +2639,13 @@ def draw_enemy(r): @partial(jax.jit, static_argnums=(0, 3)) def _render_bat_level(self, state, raster, is_blue_level: bool): bat_death_seg = jnp.maximum(1, self.consts.ENEMY_DEATH_DURATION // 3) - body_masks = self.SHAPE_MASKS['bat_blue_body'] if is_blue_level else self.SHAPE_MASKS['bat_red_body'] + body_masks = self.get_mask('bat_blue_body') if is_blue_level else self.get_mask('bat_red_body') body_offsets = self.FLIP_OFFSETS['bat_blue_body'] if is_blue_level else self.FLIP_OFFSETS['bat_red_body'] composite_name = "bat_blue_composite_anim" if is_blue_level else "bat_red_composite_anim" - has_composite = composite_name in self.SHAPE_MASKS - composite_masks = self.SHAPE_MASKS.get(composite_name, None) - composite_offsets = self.FLIP_OFFSETS.get(composite_name, jnp.array([0, 0], dtype=jnp.int32)) - wing_masks = self.SHAPE_MASKS['bat_blue_wings'] if is_blue_level else self.SHAPE_MASKS['bat_red_wings'] + composite_masks = self.get_mask(composite_name) if composite_name in self.SHAPE_MASKS else None + has_composite = composite_masks is not None + composite_offsets = self.FLIP_OFFSETS.get(composite_name + self._mask_suffix, self.FLIP_OFFSETS.get(composite_name, jnp.array([0, 0], dtype=jnp.int32))) + wing_masks = self.get_mask('bat_blue_wings') if is_blue_level else self.get_mask('bat_red_wings') wing_offsets = self.FLIP_OFFSETS['bat_blue_wings'] if is_blue_level else self.FLIP_OFFSETS['bat_red_wings'] # 7-phase cycle, each phase lasts 8 frames: # middle -> down_2 -> down -> down_2 -> middle -> up -> middle @@ -2704,7 +2742,7 @@ def _render_boss_level(self, state, raster): c = self.consts boss = state.boss - boss_mask = self.SHAPE_MASKS["boss"] + boss_mask = self.get_mask("boss") boss_flip_offset = self.FLIP_OFFSETS["boss"] core_x = boss.x - (c.BOSS_CORE_WIDTH / 2.0) core_y = boss.y + c.BOSS_CORE_Y_OFFSET @@ -2726,7 +2764,7 @@ def _render_boss_level(self, state, raster): rel_y = self.jr._yy - by # --- Blue blocks: 2 rows × 20 cols, 4×2 px each --- - blue_mask = self.SHAPE_MASKS["boss_block_blue"][0] + blue_mask = self.get_mask("boss_block_blue")[0] bh_b, bw_b = int(blue_mask.shape[0]), int(blue_mask.shape[1]) # (2, 4) n_cols_b, n_rows_b = 20, 2 blue_x0 = c.BOSS_BLUE_X0 # Python int = -40 @@ -2741,7 +2779,7 @@ def _render_boss_level(self, state, raster): raster = jnp.where(in_blue, blue_mask[0, 0], raster) # --- Red blocks: pyramid (rows [20,18,16,14,12,8,4]), 4×3 px each --- - red_mask = self.SHAPE_MASKS["boss_block_red"][0] + red_mask = self.get_mask("boss_block_red")[0] bh_r, bw_r = int(red_mask.shape[0]), int(red_mask.shape[1]) # (3, 4) dy_r = c.BOSS_RED_DY0 # Python int = 4 n_per_row_r = jnp.array([20, 18, 16, 14, 12, 8, 4], dtype=jnp.int32) @@ -2760,7 +2798,7 @@ def _render_boss_level(self, state, raster): raster = jnp.where(in_red, red_mask[0, 0], raster) # --- Green blocks: 5 rows above boss, 4×3 px, center gap at rel_x in [-4,4) --- - green_mask = self.SHAPE_MASKS["boss_block_green"][0] + green_mask = self.get_mask("boss_block_green")[0] bh_g, bw_g = int(green_mask.shape[0]), int(green_mask.shape[1]) # (3, 4) dy_g = c.BOSS_GREEN_DY0 # Python int = -3 n_per_row_g = jnp.array([12, 10, 8, 6, 4], dtype=jnp.int32) @@ -2794,7 +2832,7 @@ def render_enemy_projectile(i, current_raster): visible = (y > -1) & (y >= min_visible_y) return jax.lax.cond( visible, - lambda r: self.jr.render_at_clipped(r, x, y, self.SHAPE_MASKS["enemy_projectile"]), + lambda r: self.jr.render_at_clipped(r, x, y, self.get_mask("enemy_projectile")), lambda r: r, current_raster, ) @@ -2812,7 +2850,7 @@ def _render_ui(self, state, raster): # HUD placement: score centered in WIDTH with optional pixel nudge. score_dx, score_dy = 3, -5 score_y = 10 + score_dy - digit_masks = self.SHAPE_MASKS['digits'] + digit_masks = self.get_mask('digits') digit_w = digit_masks[0].shape[1] score_digits = self.jr.int_to_digits(state.score, max_digits=max_digits) has_nonzero = jnp.any(score_digits != 0) @@ -2828,7 +2866,7 @@ def _render_ui(self, state, raster): start_index, num_to_render, spacing=spacing, max_digits_to_render=max_digits ) - life_mask = self.SHAPE_MASKS['life_indicator'] + life_mask = self.get_mask('life_indicator') life_w = life_mask.shape[1] life_spacing = 4 lives_dx, lives_dy = 3, -7 diff --git a/src/jaxatari/games/mods/phoenix/phoenix_mod_plugins.py b/src/jaxatari/games/mods/phoenix/phoenix_mod_plugins.py index 8b3cfe34f..7eaad0bed 100644 --- a/src/jaxatari/games/mods/phoenix/phoenix_mod_plugins.py +++ b/src/jaxatari/games/mods/phoenix/phoenix_mod_plugins.py @@ -57,3 +57,82 @@ class NoAbilityCooldownMod(JaxAtariInternalModPlugin): "ABILITY_COOLDOWN": 0, } +class NightMod(JaxAtariInternalModPlugin): + """Dims the entire screen by 50% for a night mode experience.""" + name = "night_mode" + constants_overrides = { + 'SCORE_COLOR': (105, 105, 32), + 'PLAYER_COLOR': (106, 65, 37), + 'BOSS_BLUE_COLOR': (42, 46, 107), + 'BOSS_RED_COLOR': (100, 36, 36), + 'BOSS_GREEN_COLOR': (42, 80, 30), + 'RGB_BACKGROUND': (0, 0, 0), + 'RGB_FLOOR': (73, 35, 96), + 'RGB_PHOENIX_MAIN': (62, 24, 86), + 'RGB_BATS_BLUE': (66, 72, 126), + 'RGB_BATS_RED': (83, 13, 13), + } + +class GrayscaleMod(JaxAtariInternalModPlugin): + """Turns the entire game into grayscale.""" + name = "grayscale" + constants_overrides = { + 'SCORE_COLOR': (170, 170, 170), + 'PLAYER_COLOR': (120, 120, 120), + 'BOSS_BLUE_COLOR': (90, 90, 90), + 'BOSS_RED_COLOR': (100, 100, 100), + 'BOSS_GREEN_COLOR': (110, 110, 110), + 'RGB_BACKGROUND': (0, 0, 0), + 'RGB_FLOOR': (100, 100, 100), + 'RGB_PHOENIX_MAIN': (80, 80, 80), + 'RGB_BATS_BLUE': (120, 120, 120), + 'RGB_BATS_RED': (70, 70, 70), + } + +class InvertedColorsMod(JaxAtariInternalModPlugin): + """Inverts all colors in the game.""" + name = "inverted_colors" + constants_overrides = { + 'SCORE_COLOR': (45, 45, 191), + 'PLAYER_COLOR': (42, 125, 181), + 'BOSS_BLUE_COLOR': (171, 163, 41), + 'BOSS_RED_COLOR': (55, 183, 183), + 'BOSS_GREEN_COLOR': (171, 95, 195), + 'RGB_BACKGROUND': (255, 255, 255), + 'RGB_FLOOR': (109, 185, 63), + 'RGB_PHOENIX_MAIN': (130, 207, 82), + 'RGB_BATS_BLUE': (123, 111, 3), + 'RGB_BATS_RED': (88, 229, 229), + } + +class MatrixMod(JaxAtariInternalModPlugin): + """A Matrix-themed mod: black background, green elements.""" + name = "matrix_theme" + constants_overrides = { + 'SCORE_COLOR': (0, 255, 0), + 'PLAYER_COLOR': (255, 255, 255), + 'BOSS_BLUE_COLOR': (0, 150, 0), + 'BOSS_RED_COLOR': (0, 200, 0), + 'BOSS_GREEN_COLOR': (50, 255, 50), + 'RGB_BACKGROUND': (0, 0, 0), + 'RGB_FLOOR': (0, 180, 0), + 'RGB_PHOENIX_MAIN': (0, 255, 100), + 'RGB_BATS_BLUE': (0, 255, 0), + 'RGB_BATS_RED': (50, 255, 50), + } + +class BloodMoonMod(JaxAtariInternalModPlugin): + """A dark red themed mod.""" + name = "blood_moon" + constants_overrides = { + 'SCORE_COLOR': (255, 100, 100), + 'PLAYER_COLOR': (255, 255, 255), + 'BOSS_BLUE_COLOR': (150, 0, 0), + 'BOSS_RED_COLOR': (200, 50, 50), + 'BOSS_GREEN_COLOR': (180, 0, 0), + 'RGB_BACKGROUND': (40, 0, 0), + 'RGB_FLOOR': (200, 0, 0), + 'RGB_PHOENIX_MAIN': (255, 50, 50), + 'RGB_BATS_BLUE': (150, 0, 0), + 'RGB_BATS_RED': (255, 100, 100), + } diff --git a/src/jaxatari/games/mods/phoenix_mods.py b/src/jaxatari/games/mods/phoenix_mods.py index de588c7cd..d986f0777 100644 --- a/src/jaxatari/games/mods/phoenix_mods.py +++ b/src/jaxatari/games/mods/phoenix_mods.py @@ -6,6 +6,11 @@ InvinciblePlayerMod, FastEnemyBulletsMod, NoAbilityCooldownMod, + NightMod, + GrayscaleMod, + InvertedColorsMod, + MatrixMod, + BloodMoonMod, ) @@ -21,6 +26,11 @@ class PhoenixEnvMod(JaxAtariModController): "invincible_player": InvinciblePlayerMod, "fast_enemy_bullets": FastEnemyBulletsMod, "no_ability_cooldown": NoAbilityCooldownMod, + "night_mode": NightMod, + "grayscale": GrayscaleMod, + "inverted_colors": InvertedColorsMod, + "matrix_theme": MatrixMod, + "blood_moon": BloodMoonMod, } def __init__( From a8ec83cc2c989051a40698a291cebead0b59687c Mon Sep 17 00:00:00 2001 From: Quentin Delfosse Date: Mon, 4 May 2026 15:59:14 +0200 Subject: [PATCH 13/28] Reward at gate for skiing --- src/jaxatari/games/mods/skiing/skiing_mod_plugins.py | 10 ++++++++++ src/jaxatari/games/mods/skiing_mods.py | 3 ++- 2 files changed, 12 insertions(+), 1 deletion(-) diff --git a/src/jaxatari/games/mods/skiing/skiing_mod_plugins.py b/src/jaxatari/games/mods/skiing/skiing_mod_plugins.py index 864080960..e08eba3f9 100644 --- a/src/jaxatari/games/mods/skiing/skiing_mod_plugins.py +++ b/src/jaxatari/games/mods/skiing/skiing_mod_plugins.py @@ -345,3 +345,13 @@ class GreenFlagsMod(JaxAtariInternalModPlugin): constants_overrides = { "green_flags": True } + + +class RewardAtGateMod(JaxAtariInternalModPlugin): + """ + Changes the reward function to give +1 reward for passing a gate, + instead of the original ALE reward (time penalty + massive end penalty). + """ + constants_overrides = { + "USE_ORIGINAL_ALE_REWARD": False, + } diff --git a/src/jaxatari/games/mods/skiing_mods.py b/src/jaxatari/games/mods/skiing_mods.py index 601bf8bd5..21e00409d 100644 --- a/src/jaxatari/games/mods/skiing_mods.py +++ b/src/jaxatari/games/mods/skiing_mods.py @@ -4,7 +4,7 @@ MoreTreesMod, MoreMogulsMod, DangerousMogulsMod, JumpToBreakMod, SpeedBurstMod, TreesEverywhereMod, HallOfFameMod, InvertFlagsMod, InvertFlagColorsMod, MovingFlagsMod, RandomFlagsMod, FlagFlurryMod, MogulsToTreesMod, - ClassicTreesMod, ThinMogulsMod, BlueSkiierMod, GreenFlagsMod + ClassicTreesMod, ThinMogulsMod, BlueSkiierMod, GreenFlagsMod, RewardAtGateMod ) class SkiingEnvMod(JaxAtariModController): @@ -34,6 +34,7 @@ class SkiingEnvMod(JaxAtariModController): "thin_moguls": ThinMogulsMod, "blue_skiier": BlueSkiierMod, "green_flags": GreenFlagsMod, + "reward_at_gate": RewardAtGateMod, "off_piste": ["_more_trees", "_trees_everywhere", "_more_moguls", "_dangerous_moguls"], "change_sprites": ["classic_trees", "thin_moguls", "blue_skiier", "green_flags"], } From 25770820ce2ef6820e5bd2e70b38fb3b8a6e7bb5 Mon Sep 17 00:00:00 2001 From: Quentin Delfosse Date: Mon, 4 May 2026 18:50:51 +0200 Subject: [PATCH 14/28] Bankheist mods --- src/jaxatari/games/jax_bankheist.py | 14 +- .../mods/bankheist/bankheist_mod_plugins.py | 130 +++++++++++++++++- src/jaxatari/games/mods/bankheist_mods.py | 12 ++ 3 files changed, 152 insertions(+), 4 deletions(-) diff --git a/src/jaxatari/games/jax_bankheist.py b/src/jaxatari/games/jax_bankheist.py index d1b7e4173..8608485ce 100644 --- a/src/jaxatari/games/jax_bankheist.py +++ b/src/jaxatari/games/jax_bankheist.py @@ -1417,7 +1417,7 @@ def __init__(self, consts: BankHeistConstants = None, config: render_utils.Rende sprite_path = self.consts.SPRITES_DIR final_asset_config = list(self.consts.ASSET_CONFIG) - city_asset = next((a for a in final_asset_config if a['name'] == 'cities'), None) + city_asset = next((a for a in final_asset_config if a.get('name') == 'cities'), None) if city_asset: final_asset_config.remove(city_asset) city_files = city_asset['files'] @@ -1603,7 +1603,17 @@ def render(self, state): masks.append(self.DYNAMITE_BATCH_MASK) # --- Score --- - score_digits = self.jr.int_to_digits(state.money, max_digits=4) + is_negative = state.money < 0 + abs_money = jnp.abs(state.money) + score_digits = self.jr.int_to_digits(abs_money, max_digits=4) + + minus_mask = jnp.full(self.BATCH_SHAPE, self.jr.TRANSPARENT_ID, dtype=jnp.uint8) + minus_mask = minus_mask.at[3, 1:5].set(jnp.uint8(self.black_color_id)) + + xs.append(jax.lax.select(is_negative, jnp.array(74), jnp.array(-100))) + ys.append(jnp.array(179)) + masks.append(minus_mask) + for i in range(4): xs.append(jnp.array(90 + i * 12)) ys.append(jnp.array(179)) diff --git a/src/jaxatari/games/mods/bankheist/bankheist_mod_plugins.py b/src/jaxatari/games/mods/bankheist/bankheist_mod_plugins.py index 87f72398d..953e993f7 100644 --- a/src/jaxatari/games/mods/bankheist/bankheist_mod_plugins.py +++ b/src/jaxatari/games/mods/bankheist/bankheist_mod_plugins.py @@ -1,3 +1,5 @@ +import os +import numpy as np import chex import jax import jax.numpy as jnp @@ -6,6 +8,38 @@ from jaxatari.modification import JaxAtariInternalModPlugin, JaxAtariPostStepModPlugin from jaxatari.games.jax_bankheist import JaxBankHeist, BankHeistState, Entity +from jaxatari.rendering.jax_rendering_utils import get_base_sprite_dir + + +def _recolor_bankheist_sprite(filename: str, original_rgb: tuple, new_rgb: tuple) -> np.ndarray: + """Load a bankheist sprite .npy and replace original_rgb with new_rgb (alpha preserved).""" + sprite_dir = os.path.join(get_base_sprite_dir(), "bankheist") + sprite_path = os.path.join(sprite_dir, filename) + sprite = np.load(sprite_path).copy() + original = np.array([*original_rgb, 255], dtype=np.uint8) + replacement = np.array([*new_rgb, 255], dtype=np.uint8) + mask = np.all(sprite == original, axis=-1) + sprite[mask] = replacement + return sprite + + +def _recolor_bankheist_road(filename: str, original_rgb: tuple, new_rgb: tuple) -> np.ndarray: + """Load a bankheist sprite .npy and replace original_rgb with new_rgb only in the maze road area.""" + sprite_dir = os.path.join(get_base_sprite_dir(), "bankheist") + sprite_path = os.path.join(sprite_dir, filename) + sprite = np.load(sprite_path).copy() + original = np.array([*original_rgb, 255], dtype=np.uint8) + replacement = np.array([*new_rgb, 255], dtype=np.uint8) + + mask = np.all(sprite == original, axis=-1) + + # Restrict to the maze Y-coordinates (Y=45 to Y=186 inclusive) + y_mask = np.zeros_like(mask) + y_mask[45:187, 12:148] = True + + final_mask = mask & y_mask + sprite[final_mask] = replacement + return sprite class RandomBankSpawnsMod(JaxAtariInternalModPlugin): @@ -142,7 +176,7 @@ def handle_bank_robbery(self, state: BankHeistState, bank_hit_index: chex.Array) return state.replace( bank_positions=new_banks, pending_police_spawns=new_pending_spawns, - pending_police_bank_indices=new_pending_bank_indices, + pending_police_bank_indices=new_pending_indices, pending_police_spawn_positions=new_pending_spawn_positions, pending_police_scores=new_pending_scores, bank_heists=new_bank_heists, @@ -539,4 +573,96 @@ class DoubleSpeedMod(JaxAtariInternalModPlugin): def player_move_step(self, state: BankHeistState) -> BankHeistState: state = JaxBankHeist.player_move_step(self._env, state) state = JaxBankHeist.player_move_step(self._env, state) - return state \ No newline at end of file + return state + + +class GreyRoadMod(JaxAtariInternalModPlugin): + """Modifies the road color to grey.""" + asset_overrides = { + "cities": None, # Prevent BankHeistRenderer from overwriting our manual overrides + "background": { + "name": "background", + "type": "background", + "data": _recolor_bankheist_road("map_1.npy", (0, 0, 0), (80, 80, 80)) + }, + "city_maps": { + "name": "city_maps", + "type": "group", + "data": [_recolor_bankheist_road(f"map_{i+1}.npy", (0, 0, 0), (80, 80, 80)) for i in range(8)] + } + } + + +class RedPoliceCarsMod(JaxAtariInternalModPlugin): + """Modifies the color of police cars to red.""" + asset_overrides = { + "police_side": { + "name": "police_side", + "type": "single", + "data": _recolor_bankheist_sprite("police_side.npy", (24, 26, 167), (200, 0, 0)) + }, + "police_front": { + "name": "police_front", + "type": "single", + "data": _recolor_bankheist_sprite("police_front.npy", (24, 26, 167), (200, 0, 0)) + } + } + + +class GoldenBanksMod(JaxAtariInternalModPlugin): + """Modifies the color of banks to golden.""" + asset_overrides = { + "bank": { + "name": "bank", + "type": "single", + "data": _recolor_bankheist_sprite("bank.npy", (142, 142, 142), (218, 165, 32)) + } + } + + +class BluePlayerMod(JaxAtariInternalModPlugin): + """Modifies the color of the player to light blue (cyan).""" + asset_overrides = { + "player_side": { + "name": "player_side", + "type": "single", + "data": _recolor_bankheist_sprite("player_side.npy", (162, 98, 33), (0, 255, 255)) + }, + "player_front": { + "name": "player_front", + "type": "single", + "data": _recolor_bankheist_sprite("player_front.npy", (162, 98, 33), (0, 255, 255)) + } + } + + +class DynamitePenaltyMod(JaxAtariInternalModPlugin): + """ + Modifies the reward for killing a police car with dynamite to a penalty of -500. + """ + constants_overrides = { + "POLICE_KILL_REWARD": (-500, -500, -500) + } + + +class FuelForBanksMod(JaxAtariInternalModPlugin): + """ + Augments the player's fuel when 3 banks have been robbed. + """ + @partial(jax.jit, static_argnums=(0,)) + def handle_bank_robbery(self, state: BankHeistState, bank_hit_index: chex.Array) -> BankHeistState: + # Call the original method to handle the robbery logic + state = JaxBankHeist.handle_bank_robbery(self._env, state, bank_hit_index) + + # We augment the fuel if total_banks_robbed is a multiple of 3 + # Ensure we don't refill exactly at 0 (though handle_bank_robbery already increased it) + is_multiple_of_3 = (state.total_banks_robbed % 3 == 0) & (state.total_banks_robbed > 0) + + # Add a quarter of a tank for every 3 banks + fuel_bonus = self._env.consts.FUEL_CAPACITY * 0.25 + new_fuel = jnp.where(is_multiple_of_3, + jnp.minimum(state.fuel + fuel_bonus, self._env.consts.FUEL_CAPACITY), + state.fuel) + + return state.replace(fuel=new_fuel) + diff --git a/src/jaxatari/games/mods/bankheist_mods.py b/src/jaxatari/games/mods/bankheist_mods.py index 280babdda..c83cf3969 100644 --- a/src/jaxatari/games/mods/bankheist_mods.py +++ b/src/jaxatari/games/mods/bankheist_mods.py @@ -10,6 +10,12 @@ RevisitCityMod, MovingBanksMod, DoubleSpeedMod, + GreyRoadMod, + RedPoliceCarsMod, + GoldenBanksMod, + BluePlayerMod, + DynamitePenaltyMod, + FuelForBanksMod, ) # --- The Registry --- @@ -22,6 +28,12 @@ "revisit_city": RevisitCityMod, "moving_banks": MovingBanksMod, "double_speed": DoubleSpeedMod, + "grey_road": GreyRoadMod, + "red_police_cars": RedPoliceCarsMod, + "golden_banks": GoldenBanksMod, + "blue_player": BluePlayerMod, + "dynamite_penalty": DynamitePenaltyMod, + "fuel_for_banks": FuelForBanksMod, } class BankHeistEnvMod(JaxAtariModController): From f9a2d02f1e217bd5e1b3ca1c7f492cea90c2dfe5 Mon Sep 17 00:00:00 2001 From: Quentin Delfosse Date: Mon, 4 May 2026 19:37:11 +0200 Subject: [PATCH 15/28] Mspacman matrix mod --- src/jaxatari/games/jax_mspacman.py | 94 ++++++++++++++++--- .../mods/mspacman/mspacman_mod_plugins.py | 21 +++++ src/jaxatari/games/mods/mspacman_mods.py | 7 +- src/jaxatari/games/mspacman_mazes.py | 11 ++- 4 files changed, 115 insertions(+), 18 deletions(-) diff --git a/src/jaxatari/games/jax_mspacman.py b/src/jaxatari/games/jax_mspacman.py index 1eb2a964f..97b654fdc 100644 --- a/src/jaxatari/games/jax_mspacman.py +++ b/src/jaxatari/games/jax_mspacman.py @@ -100,6 +100,21 @@ class MsPacmanConstants(struct.PyTreeNode): WALL_COLOR: chex.Array = struct.field(pytree_node=False, default_factory=lambda: jnp.array([228, 111, 111], dtype=jnp.uint8)) PELLET_COLOR: chex.Array = struct.field(pytree_node=False, default_factory=lambda: jnp.array([228, 111, 111], dtype=jnp.uint8)) PACMAN_COLOR: chex.Array = struct.field(pytree_node=False, default_factory=lambda: jnp.array([210, 164, 74, 255], dtype=jnp.uint8)) + + # MOD COLORS (Optional overrides) + RGB_BACKGROUND: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_PACMAN: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_WALLS: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_PATH: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_PELLETS: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_GHOST_BLINKY: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_GHOST_PINKY: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_GHOST_INKY: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_GHOST_SUE: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_GHOST_FRIGHTENED: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_GHOST_BLINKING: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_FRUIT: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_SCORE: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) # -------- Entity classes -------- @@ -980,9 +995,17 @@ def __init__(self, consts: MsPacmanConstants = None, config: render_utils.Render sprite_path = os.path.join(render_utils.get_base_sprite_dir(), sprite_dir_name) + # Effective colors + bg_color = self.consts.RGB_BACKGROUND or (0, 0, 0) + wall_color = self.consts.RGB_WALLS or tuple(map(int, self.consts.WALL_COLOR.tolist()[:3])) + path_color = self.consts.RGB_PATH or tuple(map(int, self.consts.PATH_COLOR.tolist()[:3])) + pacman_color = self.consts.RGB_PACMAN or tuple(map(int, self.consts.PACMAN_COLOR.tolist()[:3])) + pellet_color = self.consts.RGB_PELLETS or wall_color # Default to wall color if not specified + score_color = self.consts.RGB_SCORE or (255, 255, 255) # Default white for score + # Define asset config asset_config = [ - {'name': 'dummy_bg', 'type': 'background', 'data': jnp.zeros((210, 160, 4), dtype=jnp.uint8)}, + {'name': 'dummy_bg', 'type': 'background', 'data': jnp.zeros((210, 160, 4), dtype=jnp.uint8).at[:, :, :3].set(jnp.array(bg_color, dtype=jnp.uint8))}, {'name': 'pacman_oriented', 'type': 'group', 'data': self._build_pacman_oriented_group(sprite_path)}, {'name': 'ghosts', 'type': 'group', 'files': [ 'ghost_blinky.npy', 'ghost_pinky.npy', 'ghost_inky.npy', 'ghost_sue.npy', @@ -995,33 +1018,77 @@ def __init__(self, consts: MsPacmanConstants = None, config: render_utils.Render {'name': 'digits', 'type': 'digits', 'pattern': 'score_{}.npy'}, ] + # Apply recoloring rules if any overrides are present + has_recolorings = False + for i in range(len(asset_config)): + asset = asset_config[i] + asset_name = asset['name'] + rules = [] + + if asset_name == 'pacman_oriented': + if self.consts.RGB_PACMAN is not None: + rules.append({'target': pacman_color}) + elif asset_name == 'ghosts': + # Blinky (Red), Pinky (Pink), Inky (Cyan), Sue (Orange) + # Blue (Frightened), White (Blinking) + if self.consts.RGB_GHOST_BLINKY is not None: + rules.append({'source': (228, 111, 111), 'target': self.consts.RGB_GHOST_BLINKY}) + if self.consts.RGB_GHOST_PINKY is not None: + rules.append({'source': (228, 164, 228), 'target': self.consts.RGB_GHOST_PINKY}) + if self.consts.RGB_GHOST_INKY is not None: + rules.append({'source': (24, 164, 180), 'target': self.consts.RGB_GHOST_INKY}) + if self.consts.RGB_GHOST_SUE is not None: + rules.append({'source': (210, 164, 74), 'target': self.consts.RGB_GHOST_SUE}) + if self.consts.RGB_GHOST_FRIGHTENED is not None: + rules.append({'source': (66, 72, 200), 'target': self.consts.RGB_GHOST_FRIGHTENED}) + if self.consts.RGB_GHOST_BLINKING is not None: + rules.append({'source': (255, 255, 255), 'target': self.consts.RGB_GHOST_BLINKING}) + elif asset_name == 'fruit': + if self.consts.RGB_FRUIT is not None: + rules.append({'target': self.consts.RGB_FRUIT}) + elif asset_name == 'digits': + if self.consts.RGB_SCORE is not None: + rules.append({'target': self.consts.RGB_SCORE}) + + if rules: + asset_config[i] = dict(asset) + asset_config[i]['recolorings'] = {'mods': rules} + has_recolorings = True + # Include background colors in the palette (Path, Wall, and Black for UI padding) - bg_colors = jnp.stack([self.consts.PATH_COLOR, self.consts.WALL_COLOR, jnp.array([0, 0, 0], dtype=jnp.uint8)]) + bg_colors = jnp.stack([jnp.array(path_color, dtype=jnp.uint8), jnp.array(wall_color, dtype=jnp.uint8), jnp.array(bg_color, dtype=jnp.uint8)]) bg_colors = jnp.concatenate([bg_colors, jnp.full((3, 1), 255, dtype=jnp.uint8)], axis=1) asset_config.append({'name': 'bg_colors', 'type': 'procedural', 'data': bg_colors[:, None, :]}) (self.PALETTE, self.SHAPE_MASKS, _, self.COLOR_TO_ID, self.FLIP_OFFSETS) = \ self.jr.load_and_setup_assets(asset_config, sprite_path) - for color in ( - tuple(map(int, self.consts.PATH_COLOR.tolist())), - tuple(map(int, self.consts.WALL_COLOR.tolist())), - (0, 0, 0), - ): + for color in (path_color, wall_color, bg_color): self._ensure_palette_color(color) + self._mask_suffix = '_mods' if has_recolorings else '' + + def get_mask(key): + return self.SHAPE_MASKS.get(key + self._mask_suffix, self.SHAPE_MASKS[key]) + # Pacman mask group is loaded orientation-major: # 0: UP, 1: RIGHT, 2: LEFT, 3: DOWN, each with 4 animation frames. - pacman_group = self.SHAPE_MASKS['pacman_oriented'] + pacman_group = get_mask('pacman_oriented') self.PACMAN_MASKS = pacman_group.reshape(4, 4, pacman_group.shape[1], pacman_group.shape[2]) # Pre-calculate backgrounds for all 4 mazes - self.MAZE_BACKGROUNDS = self._create_all_backgrounds() + self.MAZE_BACKGROUNDS = self._create_all_backgrounds( + jnp.array(wall_color, dtype=jnp.uint8), + jnp.array(path_color, dtype=jnp.uint8) + ) + + self.wall_id = self._resolve_color_id(wall_color) + self.pellet_id = self._resolve_color_id(pellet_color) - def _create_all_backgrounds(self): + def _create_all_backgrounds(self, wall_color=None, path_color=None): bgs = [] for i in range(4): - bg = MsPacmanMaze.load_background(i) # Returns (W, H, 3) + bg = MsPacmanMaze.load_background(i, wall_color=wall_color, path_color=path_color) # Returns (W, H, 3) bg = jnp.transpose(bg, (1, 0, 2)) # Convert to (H, W, 3) if bg.shape[2] == 3: bg = jnp.concatenate([bg, jnp.full((*bg.shape[:2], 1), 255, dtype=jnp.uint8)], axis=2) @@ -1037,11 +1104,10 @@ def render(self, state: PacmanState): raster = self.jr.create_object_raster(background) # 1. Render Pellets - wall_id = self._resolve_color_id(tuple(map(int, self.consts.WALL_COLOR.tolist()))) - raster = self.render_pellets(raster, state.level.pellets, wall_id) + raster = self.render_pellets(raster, state.level.pellets, self.pellet_id) # 2. Power Pellets - raster = self.render_power_pellets(raster, state, wall_id) + raster = self.render_power_pellets(raster, state, self.pellet_id) # 3. Pacman orientation = act_to_dir(state.player.action) diff --git a/src/jaxatari/games/mods/mspacman/mspacman_mod_plugins.py b/src/jaxatari/games/mods/mspacman/mspacman_mod_plugins.py index 0d4120634..d3a51ef6e 100644 --- a/src/jaxatari/games/mods/mspacman/mspacman_mod_plugins.py +++ b/src/jaxatari/games/mods/mspacman/mspacman_mod_plugins.py @@ -341,3 +341,24 @@ def _randomize_ghosts(self, state): ) new_ghosts = ghosts._replace(modes=new_modes) return state.replace(ghosts=new_ghosts) + + +class MatrixMod(JaxAtariInternalModPlugin): + """A Matrix-themed mod: black background, green walls, green ghosts, white pacman.""" + name = "matrix_theme" + + constants_overrides = { + 'RGB_BACKGROUND': (0, 0, 0), + 'RGB_PACMAN': (255, 255, 255), + 'RGB_WALLS': (0, 200, 0), + 'RGB_PATH': (0, 0, 0), + 'RGB_PELLETS': (0, 255, 0), + 'RGB_GHOST_BLINKY': (50, 255, 50), + 'RGB_GHOST_PINKY': (0, 255, 100), + 'RGB_GHOST_INKY': (0, 180, 0), + 'RGB_GHOST_SUE': (100, 255, 100), + 'RGB_GHOST_FRIGHTENED': (0, 100, 0), + 'RGB_GHOST_BLINKING': (150, 255, 150), + 'RGB_FRUIT': (0, 255, 0), + 'RGB_SCORE': (0, 255, 0), + } diff --git a/src/jaxatari/games/mods/mspacman_mods.py b/src/jaxatari/games/mods/mspacman_mods.py index 337038809..f30749c8d 100644 --- a/src/jaxatari/games/mods/mspacman_mods.py +++ b/src/jaxatari/games/mods/mspacman_mods.py @@ -11,7 +11,8 @@ Only1GhostMod, Only2GhostMod, Only3GhostMod, - RandomGhostNavigationMod + RandomGhostNavigationMod, + MatrixMod ) class MsPacmanEnvMod(JaxAtariModController): @@ -30,6 +31,7 @@ class MsPacmanEnvMod(JaxAtariModController): "only_2_ghost": Only2GhostMod, "only_3_ghost": Only3GhostMod, "random_ghost_navigation": RandomGhostNavigationMod, + "matrix_theme": MatrixMod, } _mod_sprite_dir = os.path.join(os.path.dirname(__file__), "mspacman", "sprites") @@ -46,3 +48,6 @@ def __init__(self, allow_conflicts=allow_conflicts, registry=self.REGISTRY ) + +# Alias for utils.py loader which uses capitalize() +MspacmanEnvMod = MsPacmanEnvMod diff --git a/src/jaxatari/games/mspacman_mazes.py b/src/jaxatari/games/mspacman_mazes.py index e8a8fcd73..e69ce3e25 100644 --- a/src/jaxatari/games/mspacman_mazes.py +++ b/src/jaxatari/games/mspacman_mazes.py @@ -259,10 +259,15 @@ def precompute_dof(maze_id: int): return dof_grid @staticmethod - def load_background(maze_id: int): + def load_background(maze_id: int, wall_color: jnp.ndarray = None, path_color: jnp.ndarray = None): """ Constructs the background based on the level. """ + if wall_color is None: + wall_color = MsPacmanMaze.WALL_COLOR + if path_color is None: + path_color = MsPacmanMaze.PATH_COLOR + # 1. Expand the maze layout to pixel scale maze = MsPacmanMaze.MAZES[maze_id] # (height, width) maze_expanded = jnp.repeat(jnp.repeat(maze, MsPacmanMaze.TILE_SCALE, axis=0), MsPacmanMaze.TILE_SCALE, axis=1) # (height*scale, width*scale) @@ -270,8 +275,8 @@ def load_background(maze_id: int): # 2. Assign color to each pixel background = jnp.where( maze_expanded[..., None], # (height*scale, width*scale, 1) - MsPacmanMaze.WALL_COLOR, # (3,) - MsPacmanMaze.PATH_COLOR # (3,) + wall_color, # (3,) + path_color # (3,) ) # 3. Pad to a height of 210 to accomodate UI From df9f1a4379cbce78b21c8dfdcc1b2581273669f6 Mon Sep 17 00:00:00 2001 From: Quentin Delfosse Date: Tue, 5 May 2026 10:28:10 +0200 Subject: [PATCH 16/28] Alien matrix --- src/jaxatari/core.py | 3 +- src/jaxatari/games/jax_alien.py | 121 ++++++++++++------ .../games/mods/alien/alien_mod_plugins.py | 14 ++ src/jaxatari/games/mods/alien_mods.py | 3 +- 4 files changed, 103 insertions(+), 38 deletions(-) diff --git a/src/jaxatari/core.py b/src/jaxatari/core.py index 303e53445..b4d944563 100644 --- a/src/jaxatari/core.py +++ b/src/jaxatari/core.py @@ -103,7 +103,8 @@ def _warn_deprecated_obs_to_flat_array(env: JaxEnvironment) -> None: "beamrider": "jaxatari.games.mods.beamrider_mods.BeamRiderEnvMod", "venture": "jaxatari.games.mods.venture_mods.VentureEnvMod", "spaceinvaders": "jaxatari.games.mods.spaceinvaders_mods.SpaceInvadersEnvMod", - "skiing": "jaxatari.games.mods.skiing_mods.SkiingEnvMod" + "skiing": "jaxatari.games.mods.skiing_mods.SkiingEnvMod", + "alien": "jaxatari.games.mods.alien_mods.AlienEnvMod" } diff --git a/src/jaxatari/games/jax_alien.py b/src/jaxatari/games/jax_alien.py index 05e923da5..49c97162c 100644 --- a/src/jaxatari/games/jax_alien.py +++ b/src/jaxatari/games/jax_alien.py @@ -1,7 +1,7 @@ import array import os from functools import partial -from typing import NamedTuple, Tuple, Any, Callable, Dict +from typing import NamedTuple, Tuple, Any, Callable, Dict, Optional import jax.numpy as jnp import chex from jaxatari.renderers import JAXGameRenderer @@ -26,7 +26,15 @@ 'FRIGHTENED': (101, 111, 228), # "Other Blue" / Killable Enemy } -def get_alien_asset_config(): +def get_alien_asset_config(consts: "AlienConstants" = None): + # Resolve colors + basic_blue = (consts.RGB_BASIC_BLUE or COLORS['BASIC_BLUE']) if consts else COLORS['BASIC_BLUE'] + orange = (consts.RGB_ORANGE or COLORS['ORANGE']) if consts else COLORS['ORANGE'] + pink = (consts.RGB_PINK or COLORS['PINK']) if consts else COLORS['PINK'] + green = (consts.RGB_GREEN or COLORS['GREEN']) if consts else COLORS['GREEN'] + yellow = (consts.RGB_YELLOW or COLORS['YELLOW']) if consts else COLORS['YELLOW'] + frightened = (consts.RGB_FRIGHTENED or COLORS['FRIGHTENED']) if consts else COLORS['FRIGHTENED'] + return [ # --- Backgrounds --- {'name': 'map_primary', 'type': 'background', 'file': 'bg/map_sprite.npy'}, @@ -41,8 +49,8 @@ def get_alien_asset_config(): 'player_animation/player3.npy', 'player_animation/player2.npy'], # Added 4th frame (ping-pong) to match teleport shape 'recolorings': { - 'normal': COLORS['BASIC_BLUE'], - 'flame': COLORS['ORANGE'] + 'normal': basic_blue, + 'flame': orange } }, { @@ -52,7 +60,7 @@ def get_alien_asset_config(): 'player_death_animation/player_death_2_sprite.npy', 'player_death_animation/player_death_3_sprite.npy', 'player_death_animation/player_death_4_sprite.npy'], - 'recolorings': {'normal': COLORS['BASIC_BLUE']} + 'recolorings': {'normal': basic_blue} }, { 'name': 'player_teleport', @@ -61,7 +69,7 @@ def get_alien_asset_config(): 'player_teleport_animation/teleport2.npy', 'player_teleport_animation/teleport3.npy', 'player_teleport_animation/teleport4.npy'], - 'recolorings': {'normal': COLORS['BASIC_BLUE']} + 'recolorings': {'normal': basic_blue} }, # --- Flame --- @@ -69,7 +77,7 @@ def get_alien_asset_config(): 'name': 'flame', 'type': 'single', 'file': 'flame/flame_sprite.npy', - 'recolorings': {'normal': COLORS['ORANGE']} + 'recolorings': {'normal': orange} }, # --- Enemies --- @@ -81,10 +89,10 @@ def get_alien_asset_config(): 'enemy_animation/enemy_walk3.npy', 'enemy_animation/enemy_walk2.npy'], # Added 4th frame (ping-pong) to match teleport shape 'recolorings': { - 'pink': COLORS['PINK'], - 'yellow': COLORS['YELLOW'], - 'green': COLORS['GREEN'], - 'frightened': COLORS['FRIGHTENED'] + 'pink': pink, + 'yellow': yellow, + 'green': green, + 'frightened': frightened } }, { @@ -95,10 +103,10 @@ def get_alien_asset_config(): 'enemy_teleport_animation/3.npy', 'enemy_teleport_animation/4.npy'], 'recolorings': { - 'pink': COLORS['PINK'], - 'yellow': COLORS['YELLOW'], - 'green': COLORS['GREEN'], - 'frightened': COLORS['FRIGHTENED'] + 'pink': pink, + 'yellow': yellow, + 'green': green, + 'frightened': frightened } }, { @@ -108,7 +116,7 @@ def get_alien_asset_config(): 'alien_death_animation/alien_death2.npy', 'alien_death_animation/alien_death3.npy', 'alien_death_animation/alien_death4.npy'], - 'recolorings': {'normal': COLORS['FRIGHTENED']} + 'recolorings': {'normal': frightened} }, # --- Items --- @@ -117,8 +125,8 @@ def get_alien_asset_config(): 'type': 'group', 'files': ['items/evil_item_1.npy', 'items/evil_item_2.npy'], 'recolorings': { - 'normal': COLORS['ORANGE'], - 'bonus_green': COLORS['GREEN'] + 'normal': orange, + 'bonus_green': green } }, { @@ -127,7 +135,7 @@ def get_alien_asset_config(): 'files': ['items/pulsar.npy', 'items/rocket.npy', 'items/saturn.npy', 'items/starship.npy', 'items/orb.npy', 'items/pi.npy'], - 'recolorings': {'normal': COLORS['YELLOW']} + 'recolorings': {'normal': yellow} }, # --- Eggs --- @@ -136,11 +144,11 @@ def get_alien_asset_config(): 'type': 'single', 'file': 'egg/egg.npy', 'recolorings': { - 'yellow': COLORS['YELLOW'], - 'orange': COLORS['ORANGE'], - 'blue': COLORS['BASIC_BLUE'], - 'pink': COLORS['PINK'], - 'green': COLORS['GREEN'] + 'yellow': yellow, + 'orange': orange, + 'blue': basic_blue, + 'pink': pink, + 'green': green } }, { @@ -148,11 +156,11 @@ def get_alien_asset_config(): 'type': 'single', 'file': 'egg/half_egg.npy', 'recolorings': { - 'yellow': COLORS['YELLOW'], - 'orange': COLORS['ORANGE'], - 'blue': COLORS['BASIC_BLUE'], - 'pink': COLORS['PINK'], - 'green': COLORS['GREEN'] + 'yellow': yellow, + 'orange': orange, + 'blue': basic_blue, + 'pink': pink, + 'green': green } }, @@ -161,13 +169,13 @@ def get_alien_asset_config(): 'name': 'digits', 'type': 'digits', 'pattern': 'digits/{}.npy', - 'recolorings': {'normal': COLORS['BASIC_BLUE']} + 'recolorings': {'normal': basic_blue} }, { 'name': 'life', 'type': 'single', 'file': 'life/life_sprite.npy', - 'recolorings': {'normal': COLORS['BASIC_BLUE']} + 'recolorings': {'normal': basic_blue} } ] @@ -236,6 +244,16 @@ class AlienConstants(struct.PyTreeNode): SCATTER_DURATION_1: int = struct.field(pytree_node=False, default=100) CHASE_DURATION_1: int = struct.field(pytree_node=False, default=200) MODECHANGE_PROBABILITY_1: float = struct.field(pytree_node=False, default=0.5) + + # MOD COLORS (Optional overrides) + RGB_BACKGROUND: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_BASIC_BLUE: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_ORANGE: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_PINK: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_GREEN: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_YELLOW: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_FRIGHTENED: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_MAP_PRIMARY: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) SCATTER_POINT_X_1: int = struct.field(pytree_node=False, default=0) SCATTER_POINT_Y_1: int = struct.field(pytree_node=False, default=0) @@ -2497,15 +2515,22 @@ def __init__(self, consts: AlienConstants = None, config: render_utils.RendererC # Pre-load background preprocessed_assets = self._load_and_preprocess_assets(sprite_path) - raw_config = get_alien_asset_config() + raw_config = get_alien_asset_config(self.consts) asset_config = [] for asset in raw_config: if asset['name'] == 'map_primary': asset_copy = asset.copy() - del asset_copy['file'] + if 'file' in asset_copy: + del asset_copy['file'] asset_copy['data'] = preprocessed_assets['map_primary'] asset_config.append(asset_copy) + elif asset['name'] == 'map_bonus': + asset_copy = asset.copy() + if 'file' in asset_copy: + del asset_copy['file'] + asset_copy['data'] = preprocessed_assets['map_bonus'] + asset_config.append(asset_copy) else: asset_config.append(asset) @@ -2522,12 +2547,35 @@ def __init__(self, consts: AlienConstants = None, config: render_utils.RendererC def _load_and_preprocess_assets(self, sprite_path: str) -> dict: target_shape = (self.consts.HEIGHT, self.consts.WIDTH, 4) - full_bg = jnp.zeros(target_shape, dtype=jnp.uint8) - full_bg = full_bg.at[:, :, 3].set(255) + bg_color = self.consts.RGB_BACKGROUND or (0, 0, 0) + full_bg = jnp.full(target_shape, jnp.array([*bg_color, 255], dtype=jnp.uint8)) map_path = os.path.join(sprite_path, "bg/map_sprite.npy") map_raw = self.jr.loadFrame(map_path) + bonus_map_path = os.path.join(sprite_path, "bg/bonus_map_sprite.npy") + bonus_map_raw = self.jr.loadFrame(bonus_map_path) + + # Original hardcoded colors in the Alien maps + orig_bg_color = jnp.array([45, 50, 184], dtype=jnp.uint8) + orig_wall_color = jnp.array([80, 0, 132], dtype=jnp.uint8) + + if self.consts.RGB_BACKGROUND is not None: + target_bg = jnp.array(self.consts.RGB_BACKGROUND, dtype=jnp.uint8) + mask_bg_primary = jnp.all(map_raw[..., :3] == orig_bg_color, axis=-1) + map_raw = map_raw.at[mask_bg_primary, :3].set(target_bg) + + mask_bg_bonus = jnp.all(bonus_map_raw[..., :3] == orig_bg_color, axis=-1) + bonus_map_raw = bonus_map_raw.at[mask_bg_bonus, :3].set(target_bg) + + if self.consts.RGB_BASIC_BLUE is not None: + target_wall = jnp.array(self.consts.RGB_BASIC_BLUE, dtype=jnp.uint8) + mask_wall_primary = jnp.all(map_raw[..., :3] == orig_wall_color, axis=-1) + map_raw = map_raw.at[mask_wall_primary, :3].set(target_wall) + + mask_wall_bonus = jnp.all(bonus_map_raw[..., :3] == orig_wall_color, axis=-1) + bonus_map_raw = bonus_map_raw.at[mask_wall_bonus, :3].set(target_wall) + # Placement: # Offsets are (Y=5, X=8) # We write map_raw into full_bg at [5:..., 8:...] @@ -2538,7 +2586,8 @@ def _load_and_preprocess_assets(self, sprite_path: str) -> dict: map_primary_full = full_bg.at[off_y:off_y+h, off_x:off_x+w, :].set(map_raw) return { - 'map_primary': map_primary_full + 'map_primary': map_primary_full, + 'map_bonus': bonus_map_raw } def _cache_sprite_stacks(self): diff --git a/src/jaxatari/games/mods/alien/alien_mod_plugins.py b/src/jaxatari/games/mods/alien/alien_mod_plugins.py index 80548999a..a8f8efae1 100644 --- a/src/jaxatari/games/mods/alien/alien_mod_plugins.py +++ b/src/jaxatari/games/mods/alien/alien_mod_plugins.py @@ -87,3 +87,17 @@ class EndGameMod(JaxAtariInternalModPlugin): constants_overrides = { "EGG_ARRAY": jnp.array(egg_positions, dtype=jnp.int32) } + +class MatrixMod(JaxAtariInternalModPlugin): + """A Matrix-themed mod: black background, green walls, green enemies, green eggs.""" + name = "matrix_theme" + + constants_overrides = { + 'RGB_BACKGROUND': (0, 0, 0), + 'RGB_BASIC_BLUE': (0, 200, 0), # Walls, Player, UI + 'RGB_ORANGE': (0, 255, 0), # Flame, etc + 'RGB_PINK': (50, 255, 50), + 'RGB_GREEN': (0, 255, 100), + 'RGB_YELLOW': (100, 255, 100), + 'RGB_FRIGHTENED': (0, 100, 0), + } diff --git a/src/jaxatari/games/mods/alien_mods.py b/src/jaxatari/games/mods/alien_mods.py index e108aae15..901f7af59 100644 --- a/src/jaxatari/games/mods/alien_mods.py +++ b/src/jaxatari/games/mods/alien_mods.py @@ -1,6 +1,6 @@ import os from jaxatari.modification import JaxAtariModController -from jaxatari.games.mods.alien.alien_mod_plugins import LastEggMod, EndGameMod +from jaxatari.games.mods.alien.alien_mod_plugins import LastEggMod, EndGameMod, MatrixMod class AlienEnvMod(JaxAtariModController): """ @@ -11,6 +11,7 @@ class AlienEnvMod(JaxAtariModController): REGISTRY = { "last_egg": LastEggMod, "end_game": EndGameMod, + "matrix_theme": MatrixMod, } _mod_sprite_dir = os.path.join(os.path.dirname(__file__), "alien", "sprites") From afefb201a2c22bf5f038fa49cca02238eb2c82da Mon Sep 17 00:00:00 2001 From: Quentin Delfosse Date: Tue, 5 May 2026 11:05:24 +0200 Subject: [PATCH 17/28] Other alien mods --- src/jaxatari/games/jax_alien.py | 23 +++++- .../games/mods/alien/alien_mod_plugins.py | 79 ++++++++++++++++++- src/jaxatari/games/mods/alien_mods.py | 11 ++- 3 files changed, 108 insertions(+), 5 deletions(-) diff --git a/src/jaxatari/games/jax_alien.py b/src/jaxatari/games/jax_alien.py index 49c97162c..d4a9f9cae 100644 --- a/src/jaxatari/games/jax_alien.py +++ b/src/jaxatari/games/jax_alien.py @@ -223,6 +223,7 @@ class AlienConstants(struct.PyTreeNode): LIFE_Y: int = struct.field(pytree_node=False, default=187) LIFE_OFFSET_X: int = struct.field(pytree_node=False, default=2) # Offset between life sprites LIFE_WIDTH: int = struct.field(pytree_node=False, default=5) + MAX_LIVES_RENDERED: int = struct.field(pytree_node=False, default=3) # Enemy_player_collision_offset ENEMY_PLAYER_COLLISION_OFFSET_Y_LOW: int = struct.field(pytree_node=False, default=4) @@ -2950,9 +2951,12 @@ def draw_egg(_r): def _render_hud(self, state: AlienState, raster): # 1. Score: right-aligned — starts as 1 digit on the right, expands left as score grows score_val = state.level.score.astype(jnp.int32) - score_digits = self.jr.int_to_digits(score_val, max_digits=6) + score_val = jnp.where(score_val > 32768, score_val - 65536, score_val) + + abs_score = jnp.abs(score_val) + score_digits = self.jr.int_to_digits(abs_score, max_digits=6) score_flat = score_digits.flatten() - n = jnp.maximum(score_val, 0) + n = jnp.maximum(abs_score, 0) num_digits = jnp.where( n > 0, jnp.ceil(jnp.log10(n.astype(jnp.float32) + 1.0)).astype(jnp.int32), @@ -2976,6 +2980,19 @@ def _render_hud(self, state: AlienState, raster): max_digits_to_render=6 ) + # Add negative sign if score < 0 + is_negative = jnp.squeeze(score_val < 0) + minus_color_id = jnp.max(self.DIGITS[0]) + minus_mask = jnp.zeros_like(self.DIGITS[0]) + minus_mask = minus_mask.at[3, 1:5].set(minus_color_id) + + raster = jax.lax.cond( + is_negative, + lambda r: self.jr.render_at(r, score_x - score_spacing + 2, self.consts.SCORE_Y + self.consts.RENDER_OFFSET_Y, minus_mask), + lambda r: r, + raster + ) + # 2. Lives (shifted up) raster = self.jr.render_indicator( raster, @@ -2984,7 +3001,7 @@ def _render_hud(self, state: AlienState, raster): jnp.squeeze(state.level.lifes), self.LIFE, spacing=self.consts.LIFE_WIDTH + self.consts.LIFE_OFFSET_X, - max_value=3 + max_value=self.consts.MAX_LIVES_RENDERED ) return raster diff --git a/src/jaxatari/games/mods/alien/alien_mod_plugins.py b/src/jaxatari/games/mods/alien/alien_mod_plugins.py index a8f8efae1..32afb0f6d 100644 --- a/src/jaxatari/games/mods/alien/alien_mod_plugins.py +++ b/src/jaxatari/games/mods/alien/alien_mod_plugins.py @@ -1,7 +1,6 @@ import jax import jax.numpy as jnp from functools import partial -from jaxatari.games.jax_pong import PongState from jaxatari.modification import JaxAtariInternalModPlugin, JaxAtariPostStepModPlugin import chex from jaxatari.environment import JAXAtariAction as Action @@ -101,3 +100,81 @@ class MatrixMod(JaxAtariInternalModPlugin): 'RGB_YELLOW': (100, 255, 100), 'RGB_FRIGHTENED': (0, 100, 0), } + +class PacifistMod(JaxAtariInternalModPlugin): + """Aliens are never frightened, forcing a pure evasion playstyle.""" + name = "pacifist_mode" + constants_overrides = { + "FRIGHTENED_DURATION": 0, + "FLAME_FRIGHTENED_DURATION": 0 + } + +class AggressiveSwarmMod(JaxAtariInternalModPlugin): + """Aliens spend almost all their time actively chasing the player.""" + name = "aggressive_swarm" + constants_overrides = { + "SCATTER_DURATION_1": 10, + "SCATTER_DURATION_2": 10, + "SCATTER_DURATION_3": 10, + "CHASE_DURATION_1": 1000, + "CHASE_DURATION_2": 1000, + "CHASE_DURATION_3": 1000, + } + +class DontKillMod(JaxAtariInternalModPlugin, JaxAtariPostStepModPlugin): + """Punishes for killing and shooting enemies.""" + name = "dont_kill" + constants_overrides = { + "EGG_SCORE_MULTIPLYER": 50, + "ENEMY_KILL_SCORE": jnp.array([-1000, -2000, -5000], dtype=jnp.int32) + } + + @partial(jax.jit, static_argnums=(0,)) + def run(self, prev_state, new_state): + # Punish for shooting (pulsar usage) + # Every frame the flame is active, we subtract points. + is_shooting = new_state.player.flame.flame_flag > 0 + shooting_punishment = jnp.where(is_shooting, 2, 0).astype(jnp.uint16) + + new_score = new_state.level.score - shooting_punishment + + return new_state.replace(level=new_state.level.replace(score=new_score)) + + @partial(jax.jit, static_argnums=(0,)) + def _get_reward(self, previous_state, state): + # Calculate signed difference with wrap-around handling for uint16 score + curr_score = state.level.score[0].astype(jnp.int32) + prev_score = previous_state.level.score[0].astype(jnp.int32) + + diff = curr_score - prev_score + # Handle uint16 wrap: if diff is very large positive, it's likely a negative change + # (e.g., 500 - 1000 = -500 -> 65036. 65036 - 500 = 64536 > 32768) + diff = jnp.where(diff > 32768, diff - 65536, diff) + diff = jnp.where(diff < -32768, diff + 65536, diff) + + return diff.astype(jnp.float32) + + @partial(jax.jit, static_argnums=(0,)) + def _get_env_reward(self, previous_state, state): + return self._get_reward(previous_state, state) + +class ShortCircuitMod(JaxAtariInternalModPlugin): + """Vulnerability window is drastically reduced.""" + name = "short_circuit" + constants_overrides = { + "FRIGHTENED_DURATION": 30, + "FLAME_FRIGHTENED_DURATION": 2 + } + +class ExtraLivesMod(JaxAtariPostStepModPlugin): + """Start with many extra lives.""" + name = "extra_lives" + + constants_overrides = { + "MAX_LIVES_RENDERED": 9 + } + + @partial(jax.jit, static_argnums=(0,)) + def after_reset(self, obs, state): + new_level = state.level.replace(lifes=jnp.array(9, dtype=jnp.int32)) + return obs, state.replace(level=new_level) diff --git a/src/jaxatari/games/mods/alien_mods.py b/src/jaxatari/games/mods/alien_mods.py index 901f7af59..4b7f26852 100644 --- a/src/jaxatari/games/mods/alien_mods.py +++ b/src/jaxatari/games/mods/alien_mods.py @@ -1,6 +1,10 @@ import os from jaxatari.modification import JaxAtariModController -from jaxatari.games.mods.alien.alien_mod_plugins import LastEggMod, EndGameMod, MatrixMod +from jaxatari.games.mods.alien.alien_mod_plugins import ( + LastEggMod, EndGameMod, MatrixMod, + PacifistMod, AggressiveSwarmMod, DontKillMod, + ShortCircuitMod, ExtraLivesMod +) class AlienEnvMod(JaxAtariModController): """ @@ -12,6 +16,11 @@ class AlienEnvMod(JaxAtariModController): "last_egg": LastEggMod, "end_game": EndGameMod, "matrix_theme": MatrixMod, + "pacifist_mode": PacifistMod, + "aggressive_swarm": AggressiveSwarmMod, + "dont_kill": DontKillMod, + "short_circuit": ShortCircuitMod, + "extra_lives": ExtraLivesMod, } _mod_sprite_dir = os.path.join(os.path.dirname(__file__), "alien", "sprites") From ce9705953da6190c499d1fcd92c8cac2999a0127 Mon Sep 17 00:00:00 2001 From: PaulSeitz Date: Tue, 5 May 2026 15:23:37 +0200 Subject: [PATCH 18/28] [Mods] fixed test failures and eased trigger for tripwire warning --- .../mods/bankheist/bankheist_mod_plugins.py | 2 +- .../games/mods/skiing/skiing_mod_plugins.py | 2 +- src/jaxatari/modification.py | 38 ++++++++++++++----- 3 files changed, 30 insertions(+), 12 deletions(-) diff --git a/src/jaxatari/games/mods/bankheist/bankheist_mod_plugins.py b/src/jaxatari/games/mods/bankheist/bankheist_mod_plugins.py index 953e993f7..e1d77f774 100644 --- a/src/jaxatari/games/mods/bankheist/bankheist_mod_plugins.py +++ b/src/jaxatari/games/mods/bankheist/bankheist_mod_plugins.py @@ -176,7 +176,7 @@ def handle_bank_robbery(self, state: BankHeistState, bank_hit_index: chex.Array) return state.replace( bank_positions=new_banks, pending_police_spawns=new_pending_spawns, - pending_police_bank_indices=new_pending_indices, + pending_police_bank_indices=new_pending_bank_indices, pending_police_spawn_positions=new_pending_spawn_positions, pending_police_scores=new_pending_scores, bank_heists=new_bank_heists, diff --git a/src/jaxatari/games/mods/skiing/skiing_mod_plugins.py b/src/jaxatari/games/mods/skiing/skiing_mod_plugins.py index c9e08b57d..36603d7f7 100644 --- a/src/jaxatari/games/mods/skiing/skiing_mod_plugins.py +++ b/src/jaxatari/games/mods/skiing/skiing_mod_plugins.py @@ -83,7 +83,7 @@ def _get_initial_flags_x(self) -> chex.Array: @partial(jax.jit, static_argnums=(0,)) def _get_new_flag_x(self, state, i: chex.Array) -> chex.Array: - return jnp.float32(60.0) + return jnp.full(i.shape, 60.0, dtype=jnp.float32) @partial(jax.jit, static_argnums=(0,)) def _get_initial_trees_x(self) -> chex.Array: diff --git a/src/jaxatari/modification.py b/src/jaxatari/modification.py index 97120ab44..64d38930f 100644 --- a/src/jaxatari/modification.py +++ b/src/jaxatari/modification.py @@ -2,6 +2,7 @@ import importlib import types from typing import Any, Dict +from contextlib import contextmanager import jax import chex import warnings @@ -53,6 +54,8 @@ def _mark_jit_mutation(core_env, reason: str) -> None: Mark that runtime behavior changed. Warn if this happened after tracing. """ core_env._jit_mutation_epoch = getattr(core_env, "_jit_mutation_epoch", 0) + 1 + if getattr(core_env, "_jit_tripwire_suppression_depth", 0) > 0: + return if not getattr(core_env, "_jit_tripwire_enabled", True): return compiled_targets = _targets_with_compiled_cache(core_env) @@ -69,6 +72,20 @@ def _mark_jit_mutation(core_env, reason: str) -> None: ) +@contextmanager +def _suspend_jit_tripwire(core_env): + """ + Temporarily suppress tripwire warnings for controlled setup-time mutations. + """ + core_env._jit_tripwire_suppression_depth = getattr(core_env, "_jit_tripwire_suppression_depth", 0) + 1 + try: + yield + finally: + core_env._jit_tripwire_suppression_depth = max( + 0, getattr(core_env, "_jit_tripwire_suppression_depth", 1) - 1 + ) + + def _clear_registered_jit_caches(core_env) -> None: """ Clear all registered jitted callables that may have captured old renderer state. @@ -991,16 +1008,17 @@ def expand_mods(mod_list, depth=0): else: base_env._mod_history["constant"].add(k) - modded_env = ControllerClass( - env=base_env, - mods_config=expanded_mods_config, - allow_conflicts=allow_conflicts - ) + with _suspend_jit_tripwire(base_env): + modded_env = ControllerClass( + env=base_env, + mods_config=expanded_mods_config, + allow_conflicts=allow_conflicts + ) - final_env = JaxAtariModWrapper( - env=modded_env, - mods_config=expanded_mods_config, - allow_conflicts=allow_conflicts - ) + final_env = JaxAtariModWrapper( + env=modded_env, + mods_config=expanded_mods_config, + allow_conflicts=allow_conflicts + ) return final_env \ No newline at end of file From dfcd139cb48a61c4db2a686a4ce68c61352224d2 Mon Sep 17 00:00:00 2001 From: Quentin Delfosse Date: Tue, 5 May 2026 16:53:46 +0200 Subject: [PATCH 19/28] Time Pilot mods --- src/jaxatari/core.py | 3 +- src/jaxatari/games/jax_timepilot.py | 134 ++++++++-- .../games/mods/alien/alien_mod_plugins.py | 11 +- .../mods/timepilot/timepilot_mod_plugins.py | 243 ++++++++++++++++++ src/jaxatari/games/mods/timepilot_mods.py | 38 +++ 5 files changed, 409 insertions(+), 20 deletions(-) create mode 100644 src/jaxatari/games/mods/timepilot/timepilot_mod_plugins.py create mode 100644 src/jaxatari/games/mods/timepilot_mods.py diff --git a/src/jaxatari/core.py b/src/jaxatari/core.py index b4d944563..2f9459393 100644 --- a/src/jaxatari/core.py +++ b/src/jaxatari/core.py @@ -104,7 +104,8 @@ def _warn_deprecated_obs_to_flat_array(env: JaxEnvironment) -> None: "venture": "jaxatari.games.mods.venture_mods.VentureEnvMod", "spaceinvaders": "jaxatari.games.mods.spaceinvaders_mods.SpaceInvadersEnvMod", "skiing": "jaxatari.games.mods.skiing_mods.SkiingEnvMod", - "alien": "jaxatari.games.mods.alien_mods.AlienEnvMod" + "alien": "jaxatari.games.mods.alien_mods.AlienEnvMod", + "timepilot": "jaxatari.games.mods.timepilot_mods.TimePilotEnvMod" } diff --git a/src/jaxatari/games/jax_timepilot.py b/src/jaxatari/games/jax_timepilot.py index c0c03bd62..17f0b5bd1 100644 --- a/src/jaxatari/games/jax_timepilot.py +++ b/src/jaxatari/games/jax_timepilot.py @@ -214,6 +214,18 @@ class TimePilotConstants(AutoDerivedConstants): LEVEL_4: LevelConstants = struct.field(pytree_node=False, default=TimePilot_Level_4) LEVEL_5: LevelConstants = struct.field(pytree_node=False, default=TimePilot_Level_5) + # MOD COLORS (Optional overrides) + RGB_BACKGROUND: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_PLAYER: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_ENEMY: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_BOSS: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_CLOUD: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_MISSILE: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_UI: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_SCORE: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_LIVES: Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + RGB_BOTTOM_WALL:Optional[Tuple[int, int, int]] = struct.field(pytree_node=False, default=None) + # Asset config baked into constants (immutable default) for asset overrides ASSET_CONFIG: tuple = struct.field(pytree_node=False, default=_get_default_asset_config()) @@ -1477,7 +1489,74 @@ def __init__(self, consts: TimePilotConstants|None = None, config: render_utils. self.jr = render_utils.JaxRenderingUtils(self.config) # 2. Start from (possibly modded) asset config provided via constants - final_asset_config = list(self.consts.ASSET_CONFIG) + asset_config = list(self.consts.ASSET_CONFIG) + + # Apply recoloring rules if any RGB overrides are present + has_recolorings = False + ui_source_colors = [ + (51, 26, 163), (84, 92, 214), (167, 26, 26), (195, 144, 61), + (214, 214, 214), (78, 50, 181), (142, 142, 142), (84, 138, 210), + (135, 183, 84), (168, 48, 143) + ] + for i in range(len(asset_config)): + asset = asset_config[i] + asset_name = asset['name'] + rules = [] + + if asset_name == 'all_player_sprites': + if self.consts.RGB_PLAYER is not None: + rules.append({'target': self.consts.RGB_PLAYER}) + elif asset_name == 'all_enemy_sprites': + if self.consts.RGB_ENEMY is not None: + rules.append({'target': self.consts.RGB_ENEMY}) + elif self.consts.RGB_BOSS is not None: + rules.append({'target': self.consts.RGB_BOSS}) + elif asset_name == 'all_clouds': + if self.consts.RGB_CLOUD is not None: + rules.append({'target': self.consts.RGB_CLOUD}) + elif asset_name in ('all_player_missiles', 'all_enemy_missiles'): + if self.consts.RGB_MISSILE is not None: + rules.append({'target': self.consts.RGB_MISSILE}) + elif asset_name == 'digits': + if self.consts.RGB_SCORE is not None: + rules.append({'target': self.consts.RGB_SCORE}) + elif self.consts.RGB_UI is not None: + rules.append({'target': self.consts.RGB_UI}) + elif asset_name == 'player_life': + if self.consts.RGB_LIVES is not None: + rules.append({'target': self.consts.RGB_LIVES}) + elif self.consts.RGB_UI is not None: + rules.append({'target': self.consts.RGB_UI}) + elif asset_name in ('transition_bar', 'all_enemy_remaining'): + if self.consts.RGB_UI is not None: + rules.append({'target': self.consts.RGB_UI}) + elif asset_name in ('bottom_wall', 'respawn_bottom_wall'): + target_color = self.consts.RGB_BOTTOM_WALL if self.consts.RGB_BOTTOM_WALL is not None else self.consts.RGB_UI + if target_color is not None: + for c in ui_source_colors: + rules.append({'source': c, 'target': target_color}) + if self.consts.RGB_BACKGROUND is not None: + rules.append({'source': (0, 0, 0), 'target': self.consts.RGB_BACKGROUND}) + elif asset_name in ('start_screen', 'top_wall', 'all_respawn_top_walls'): + if self.consts.RGB_UI is not None: + for c in ui_source_colors: + rules.append({'source': c, 'target': self.consts.RGB_UI}) + if self.consts.RGB_BACKGROUND is not None: + rules.append({'source': (0, 0, 0), 'target': self.consts.RGB_BACKGROUND}) + elif asset_name == 'all_backgrounds': + if self.consts.RGB_BACKGROUND is not None: + rules.append({'target': self.consts.RGB_BACKGROUND}) + elif asset_name == 'background': + if self.consts.RGB_BACKGROUND is not None: + # Background is procedural black in default config + asset_config[i] = dict(asset) + asset_config[i]['data'] = asset['data'].at[:, :, :3].set(jnp.array(self.consts.RGB_BACKGROUND, dtype=jnp.uint8)) + continue + + if rules: + asset_config[i] = dict(asset) + asset_config[i]['recolorings'] = {'mods': rules} + has_recolorings = True # 3. Load, process, and set up all assets in one call ( @@ -1486,7 +1565,9 @@ def __init__(self, consts: TimePilotConstants|None = None, config: render_utils. self.BACKGROUND, # This will be our empty (black) raster self.COLOR_TO_ID, self.FLIP_OFFSETS - ) = self.jr.load_and_setup_assets(final_asset_config, self.sprite_path) + ) = self.jr.load_and_setup_assets(asset_config, self.sprite_path) + + self._mask_suffix = '_mods' if has_recolorings else '' # 4. Get specific color IDs we'll need for procedural drawing self.BLACK_ID = self.COLOR_TO_ID.get((0, 0, 0), 0) @@ -1501,6 +1582,8 @@ def _post_process_sprites(self): Organizes the flat SHAPE_MASKS and FLIP_OFFSETS from setup into the nested list[dict] structure that the render() method expects. """ + def get_mask(key): + return self.SHAPE_MASKS.get(key + self._mask_suffix, self.SHAPE_MASKS[key]) # --- General Sprites --- self.general_sprites = {} @@ -1512,12 +1595,12 @@ def _post_process_sprites(self): 'player_life', 'black_line', 'transition_bar', 'digits' ] for key in simple_general_keys: - self.general_sprites[key] = self.SHAPE_MASKS[key] + self.general_sprites[key] = get_mask(key) self.general_offsets[key] = self.FLIP_OFFSETS[key] # Sliced mappings - self.general_sprites['transition_player_pos'] = self.SHAPE_MASKS['all_player_sprites'][50:58] - self.general_sprites['transition_player_death'] = self.SHAPE_MASKS['all_player_sprites'][58] + self.general_sprites['transition_player_pos'] = get_mask('all_player_sprites')[50:58] + self.general_sprites['transition_player_death'] = get_mask('all_player_sprites')[58] self.general_offsets['transition_player_pos'] = self.FLIP_OFFSETS['all_player_sprites'] self.general_offsets['transition_player_death'] = self.FLIP_OFFSETS['all_player_sprites'] @@ -1530,34 +1613,34 @@ def _post_process_sprites(self): level_dict_offsets = {} # Simple group slices - level_dict_masks['cloud'] = self.SHAPE_MASKS['all_clouds'][i] + level_dict_masks['cloud'] = get_mask('all_clouds')[i] level_dict_offsets['cloud'] = self.FLIP_OFFSETS['all_clouds'] - level_dict_masks['background'] = self.SHAPE_MASKS['all_backgrounds'][i] + level_dict_masks['background'] = get_mask('all_backgrounds')[i] level_dict_offsets['background'] = self.FLIP_OFFSETS['all_backgrounds'] - level_dict_masks['respawn_top_wall'] = self.SHAPE_MASKS['all_respawn_top_walls'][i] + level_dict_masks['respawn_top_wall'] = get_mask('all_respawn_top_walls')[i] level_dict_offsets['respawn_top_wall'] = self.FLIP_OFFSETS['all_respawn_top_walls'] - level_dict_masks['player_missile'] = self.SHAPE_MASKS['all_player_missiles'][i] + level_dict_masks['player_missile'] = get_mask('all_player_missiles')[i] level_dict_offsets['player_missile'] = self.FLIP_OFFSETS['all_player_missiles'] - level_dict_masks['enemy_missile'] = self.SHAPE_MASKS['all_enemy_missiles'][i] + level_dict_masks['enemy_missile'] = get_mask('all_enemy_missiles')[i] level_dict_offsets['enemy_missile'] = self.FLIP_OFFSETS['all_enemy_missiles'] # Complex group slices - level_dict_masks['player_pos'] = self.SHAPE_MASKS['all_player_sprites'][i*10 : i*10+8] - level_dict_masks['player_death'] = self.SHAPE_MASKS['all_player_sprites'][i*10+8 : (i+1)*10] + level_dict_masks['player_pos'] = get_mask('all_player_sprites')[i*10 : i*10+8] + level_dict_masks['player_death'] = get_mask('all_player_sprites')[i*10+8 : (i+1)*10] level_dict_offsets['player_pos'] = self.FLIP_OFFSETS['all_player_sprites'] level_dict_offsets['player_death'] = self.FLIP_OFFSETS['all_player_sprites'] - level_dict_masks['enemy_pos'] = self.SHAPE_MASKS['all_enemy_sprites'][i*15 : i*15+10] - level_dict_masks['enemy_death'] = self.SHAPE_MASKS['all_enemy_sprites'][i*15+10] + level_dict_masks['enemy_pos'] = get_mask('all_enemy_sprites')[i*15 : i*15+10] + level_dict_masks['enemy_death'] = get_mask('all_enemy_sprites')[i*15+10] level_dict_offsets['enemy_pos'] = self.FLIP_OFFSETS['all_enemy_sprites'] level_dict_offsets['enemy_death'] = self.FLIP_OFFSETS['all_enemy_sprites'] # Boss sprites (4 per level) - boss_sprites = self.SHAPE_MASKS['all_enemy_sprites'][i*15+11 : i*15+15] + boss_sprites = get_mask('all_enemy_sprites')[i*15+11 : i*15+15] ( level_dict_masks['level_boss_left_right'], level_dict_masks['level_boss_left_left'], @@ -1570,14 +1653,15 @@ def _post_process_sprites(self): level_dict_offsets['level_boss_right_right'] = self.FLIP_OFFSETS['all_enemy_sprites'] # Enemy remaining indicators (2 per level) - level_dict_masks['enemy_remaining'] = self.SHAPE_MASKS['all_enemy_remaining'][i*2] - level_dict_masks['enemy_remaining_brown'] = self.SHAPE_MASKS['all_enemy_remaining'][i*2+1] + level_dict_masks['enemy_remaining'] = get_mask('all_enemy_remaining')[i*2] + level_dict_masks['enemy_remaining_brown'] = get_mask('all_enemy_remaining')[i*2+1] level_dict_offsets['enemy_remaining'] = self.FLIP_OFFSETS['all_enemy_remaining'] level_dict_offsets['enemy_remaining_brown'] = self.FLIP_OFFSETS['all_enemy_remaining'] self.level_sprites.append(level_dict_masks) self.level_offsets.append(level_dict_offsets) + @partial(jax.jit, static_argnums=(0,)) def render(self, state: TimePilotState) -> chex.Array: """ @@ -1837,9 +1921,23 @@ def render_cloud(i, r): # --- 9. Render UI (Score, Lives) --- # Score digit_masks = general_sprites['digits'] - digits = self.jr.int_to_digits(state.score, max_digits = 6) + abs_score = jnp.abs(state.score) + digits = self.jr.int_to_digits(abs_score, max_digits = 6) raster = self.jr.render_label_selective(raster, 57, 7, digits, digit_masks, 0, 6, spacing = 8, max_digits_to_render=6) + # Add negative sign if score < 0 + is_negative = jnp.squeeze(state.score < 0) + minus_color_id = jnp.max(digit_masks[0]) + minus_mask = jnp.zeros_like(digit_masks[0]) + minus_mask = minus_mask.at[3, 1:5].set(minus_color_id) + + raster = jax.lax.cond( + is_negative, + lambda r: self.jr.render_at(r, 57 - 8 + 2, 7, minus_mask), + lambda r: r, + raster + ) + # Lives raster = self.jr.render_indicator( raster, 88, 18, state.lives - 1, diff --git a/src/jaxatari/games/mods/alien/alien_mod_plugins.py b/src/jaxatari/games/mods/alien/alien_mod_plugins.py index 32afb0f6d..ea832f816 100644 --- a/src/jaxatari/games/mods/alien/alien_mod_plugins.py +++ b/src/jaxatari/games/mods/alien/alien_mod_plugins.py @@ -8,6 +8,7 @@ # --- 1. Individual Mod Plugins --- class LastEggMod(JaxAtariInternalModPlugin): + """Only one egg is remaining.""" key = jax.random.key(-1) egg_selected = jax.random.choice(key, 106) egg_positions = [ @@ -47,6 +48,7 @@ class LastEggMod(JaxAtariInternalModPlugin): } class EndGameMod(JaxAtariInternalModPlugin): + """Only one area has remaining eggs.""" key = jax.random.key(-1) cluster_selected = jax.random.choice(key, 2, (2, )) cluster = list(range(38, 53)) if cluster_selected[1] else list(range(16)) @@ -136,7 +138,14 @@ def run(self, prev_state, new_state): is_shooting = new_state.player.flame.flame_flag > 0 shooting_punishment = jnp.where(is_shooting, 2, 0).astype(jnp.uint16) - new_score = new_state.level.score - shooting_punishment + # Reward 10 points every 100 frames + survival_reward = jnp.where( + jnp.logical_and(new_state.step_counter > 0, new_state.step_counter % 100 == 0), + 10, + 0 + ).astype(jnp.uint16) + + new_score = new_state.level.score - shooting_punishment + survival_reward return new_state.replace(level=new_state.level.replace(score=new_score)) diff --git a/src/jaxatari/games/mods/timepilot/timepilot_mod_plugins.py b/src/jaxatari/games/mods/timepilot/timepilot_mod_plugins.py new file mode 100644 index 000000000..7781da177 --- /dev/null +++ b/src/jaxatari/games/mods/timepilot/timepilot_mod_plugins.py @@ -0,0 +1,243 @@ +import jax +import jax.numpy as jnp +from functools import partial +from jaxatari.modification import JaxAtariInternalModPlugin, JaxAtariPostStepModPlugin +from jaxatari.environment import JAXAtariAction as Action + +from jaxatari.games.timepilot_levels import ( + TimePilot_Level_1, + TimePilot_Level_2, + TimePilot_Level_3, + TimePilot_Level_4, + TimePilot_Level_5 +) + +def get_ordered_asset_config(order): + """ + Returns a declarative asset manifest with levels in a custom order. + """ + all_player_sprites_files = [] + for i in order: + all_player_sprites_files.extend([ + *(f'L{i}/L{i}_Player_Pos{j}.npy' for j in range(8)), + f'L{i}/L{i}_Player_Death1.npy', f'L{i}/L{i}_Player_Death2.npy', + ]) + # Add transition sprites + all_player_sprites_files.extend([ + *(f'L-All/TP_Player_Pos{i}.npy' for i in range(8)), + 'L-All/TP_Player_Death.npy', + ]) + + all_enemy_sprites_files = [] + for i in order: + if i == 3: + all_enemy_sprites_files.extend([ + *(f'L3/L3_Enemy_Pos{j}.npy' for j in ["01", "02", "11", "12", "21", "22", "31", "32", "41", "42"]), + 'L3/L3_Enemy_Death.npy', + *(f'L3/L3_Boss_Pos{j}.npy' for j in ["01", "02", "11", "12"]), + ]) + elif i == 5: + all_enemy_sprites_files.extend([ + *(f'L5/L5_Enemy_Pos{j}.npy' for k in range(5) for j in range(2)), + 'L5/L5_Enemy_Death.npy', + 'L5/L5_Boss_Pos0.npy', 'L5/L5_Boss_Pos0.npy', + 'L5/L5_Boss_Pos1.npy', 'L5/L5_Boss_Pos1.npy', + ]) + else: + all_enemy_sprites_files.extend([ + *(f'L{i}/L{i}_Enemy_Pos{j}.npy' for j in range(8)), + f'L{i}/L{i}_Enemy_Pos0.npy', f'L{i}/L{i}_Enemy_Pos1.npy', + f'L{i}/L{i}_Enemy_Death.npy', + f'L{i}/L{i}_Boss_Pos0.npy', f'L{i}/L{i}_Boss_Pos0.npy', + f'L{i}/L{i}_Boss_Pos1.npy', f'L{i}/L{i}_Boss_Pos1.npy', + ]) + + return ( + # Procedural background (empty black screen) + {'name': 'background', 'type': 'background', 'data': jnp.zeros((210, 160, 4), dtype=jnp.uint8)}, + # Procedural pixel to ensure white is in the palette + {'name': 'white_pixel', 'type': 'procedural', 'data': jnp.array([[[255,255,255,255]]], dtype=jnp.uint8)}, + # General Sprites (Single) + {'name': 'top_wall', 'type': 'single', 'file': 'L-All/Top.npy'}, + {'name': 'bottom_wall', 'type': 'single', 'file': 'L-All/Bottom.npy'}, + {'name': 'respawn_bottom_wall', 'type': 'single', 'file': 'L-All/Respawn_Bottom.npy'}, + {'name': 'start_screen', 'type': 'single', 'file': 'L-All/First.npy'}, + {'name': 'player_life', 'type': 'single', 'file': 'L-All/Player_Life.npy'}, + {'name': 'black_line', 'type': 'single', 'file': 'L-All/BlackLine.npy'}, + # General Sprites (Group) + {'name': 'transition_bar', 'type': 'group', 'files': ['L-All/TeleportBar.npy', 'L-All/TeleportBar2.npy']}, + # General Sprites (Digits) + {'name': 'digits', 'type': 'digits', 'pattern': 'L-All/Digit{}.npy'}, + # --- Level-Dependent Groups --- + {'name': 'all_clouds', 'type': 'group', 'files': [f'L{i}/L{i}_Cloud.npy' for i in order]}, + {'name': 'all_backgrounds', 'type': 'group', 'files': [f'L{i}/L{i}_Background.npy' for i in order]}, + {'name': 'all_respawn_top_walls', 'type': 'group', 'files': [f'L{i}/L{i}_Top.npy' for i in order]}, + {'name': 'all_player_missiles', 'type': 'group', 'files': [f'L{i}/L{i}_Player_Bullet.npy' for i in order]}, + {'name': 'all_enemy_missiles', 'type': 'group', 'files': [f'L{i}/L{i}_Enemy_Bullet.npy' for i in order]}, + {'name': 'all_enemy_remaining', 'type': 'group', 'files': [ + item for i in order for item in (f'L{i}/L{i}_Enemy_Life.npy', f'L{i}/L{i}_Enemy_Death_Life.npy') + ]}, + # Massive groups + {'name': 'all_player_sprites', 'type': 'group', 'files': all_player_sprites_files}, + {'name': 'all_enemy_sprites', 'type': 'group', 'files': all_enemy_sprites_files}, + ) + +class ReverseChronologyMod(JaxAtariInternalModPlugin): + """Changes the order of appearance of the enemies by reversing the level sequence (reverse chronology).""" + name = "reverse_chronology" + + _levels = [ + TimePilot_Level_1, + TimePilot_Level_2, + TimePilot_Level_3, + TimePilot_Level_4, + TimePilot_Level_5 + ] + + _order = [5, 4, 3, 2, 1] + + constants_overrides = { + "LEVEL_1": _levels[_order[0]-1], + "LEVEL_2": _levels[_order[1]-1], + "LEVEL_3": _levels[_order[2]-1], + "LEVEL_4": _levels[_order[3]-1], + "LEVEL_5": _levels[_order[4]-1], + "ASSET_CONFIG": get_ordered_asset_config(_order) + } + +class InstantTurnMod(JaxAtariInternalModPlugin): + """Directly places the plane in the direction given by the action instead of progressively turning it, including diagonals.""" + name = "instant_turn" + + # Override the game's ACTION_SET to include diagonal actions + attribute_overrides = { + "ACTION_SET": jnp.array( + [ + Action.NOOP, + Action.FIRE, + Action.UP, + Action.RIGHT, + Action.LEFT, + Action.DOWN, + Action.UPRIGHT, + Action.UPLEFT, + Action.DOWNRIGHT, + Action.DOWNLEFT, + Action.UPFIRE, + Action.RIGHTFIRE, + Action.LEFTFIRE, + Action.DOWNFIRE, + Action.UPRIGHTFIRE, + Action.UPLEFTFIRE, + Action.DOWNRIGHTFIRE, + Action.DOWNLEFTFIRE, + ], + dtype=jnp.int32, + ) + } + + @partial(jax.jit, static_argnums=(0,)) + def player_step( + self, + state_player_rotation, + action + ): + # get pressed buttons + left = jnp.logical_or(jnp.logical_or(action == Action.LEFT, action == Action.LEFTFIRE), + jnp.logical_or(jnp.logical_or(action == Action.UPLEFT, action == Action.UPLEFTFIRE), + jnp.logical_or(action == Action.DOWNLEFT, action == Action.DOWNLEFTFIRE))) + right = jnp.logical_or(jnp.logical_or(action == Action.RIGHT, action == Action.RIGHTFIRE), + jnp.logical_or(jnp.logical_or(action == Action.UPRIGHT, action == Action.UPRIGHTFIRE), + jnp.logical_or(action == Action.DOWNRIGHT, action == Action.DOWNRIGHTFIRE))) + up = jnp.logical_or(jnp.logical_or(action == Action.UP, action == Action.UPFIRE), + jnp.logical_or(jnp.logical_or(action == Action.UPLEFT, action == Action.UPLEFTFIRE), + jnp.logical_or(action == Action.UPRIGHT, action == Action.UPRIGHTFIRE))) + down = jnp.logical_or(jnp.logical_or(action == Action.DOWN, action == Action.DOWNFIRE), + jnp.logical_or(jnp.logical_or(action == Action.DOWNLEFT, action == Action.DOWNLEFTFIRE), + jnp.logical_or(action == Action.DOWNRIGHT, action == Action.DOWNRIGHTFIRE))) + + # determine new rotation according to action (up=0, up-left=1, left=2, down-left=3, down=4, down-right=5, right=6, up-right=7) + new_rotation = jax.lax.cond( + up, + lambda: jax.lax.cond(left, lambda: 1, lambda: jax.lax.cond(right, lambda: 7, lambda: 0)), + lambda: jax.lax.cond( + down, + lambda: jax.lax.cond(left, lambda: 3, lambda: jax.lax.cond(right, lambda: 5, lambda: 4)), + lambda: jax.lax.cond( + left, lambda: 2, lambda: jax.lax.cond(right, lambda: 6, lambda: state_player_rotation) + ) + ) + ) + + return jax.lax.cond( + jnp.logical_or(jnp.logical_or(up, down), jnp.logical_or(right, left)), + lambda: new_rotation, + lambda: state_player_rotation + ) + +class DontKillMod(JaxAtariInternalModPlugin, JaxAtariPostStepModPlugin): + """Punishes for killing and shooting enemies.""" + name = "dont_kill" + constants_overrides = { + "POINTS_PER_ENEMY": -1000, + "POINTS_PER_BOSS": -5000 + } + + @partial(jax.jit, static_argnums=(0,)) + def run(self, prev_state, new_state): + # Punish for shooting + # In TimePilot, player_missile_state[3] is the step counter for the missile. + # If it's > 0, the missile is active and flying. + is_shooting = new_state.player_missile_state[3] > 0 + shooting_punishment = jnp.where(is_shooting, 2, 0).astype(jnp.int32) + + # Reward 10 points every 100 frames + survival_reward = jnp.where( + jnp.logical_and(new_state.step_counter > 0, new_state.step_counter % 100 == 0), + 10, + 0 + ).astype(jnp.int32) + + new_score = new_state.score - shooting_punishment + survival_reward + + return new_state.replace(score=new_score) + + @partial(jax.jit, static_argnums=(0,)) + def _get_reward(self, previous_state, state): + # Calculate signed difference directly as score is int32 and won't wrap around + diff = state.score - previous_state.score + return diff.astype(jnp.float32) + + @partial(jax.jit, static_argnums=(0,)) + def _get_env_reward(self, previous_state, state): + return self._get_reward(previous_state, state) + +class MatrixMod(JaxAtariInternalModPlugin): + """A Matrix-themed mod for TimePilot: black background and green everything.""" + name = "matrix_theme" + constants_overrides = { + 'RGB_BACKGROUND': (0, 0, 0), + 'RGB_PLAYER': (255, 255, 255), # Keeping player white for visibility + 'RGB_ENEMY': (0, 200, 0), + 'RGB_BOSS': (0, 255, 0), + 'RGB_CLOUD': (0, 100, 0), + 'RGB_MISSILE': (0, 255, 100), + 'RGB_UI': (0, 255, 0), + 'RGB_SCORE': (0, 0, 0), + 'RGB_LIVES': (0, 0, 0), + 'RGB_BOTTOM_WALL': (0, 0, 0), + } + +class ExtraLivesMod(JaxAtariPostStepModPlugin): + """Start with many extra lives.""" + name = "extra_lives" + + constants_overrides = { + "INITIAL_LIVES": 9, + "MAX_LIVES": 9 + } + + @partial(jax.jit, static_argnums=(0,)) + def after_reset(self, obs, state): + return obs, state.replace(lives=jnp.array(9, dtype=jnp.int32)) + diff --git a/src/jaxatari/games/mods/timepilot_mods.py b/src/jaxatari/games/mods/timepilot_mods.py new file mode 100644 index 000000000..ae3a7e281 --- /dev/null +++ b/src/jaxatari/games/mods/timepilot_mods.py @@ -0,0 +1,38 @@ +import os +from jaxatari.modification import JaxAtariModController +from jaxatari.games.mods.timepilot.timepilot_mod_plugins import ( + DontKillMod, + MatrixMod, + ExtraLivesMod, + InstantTurnMod, + ReverseChronologyMod +) + +class TimePilotEnvMod(JaxAtariModController): + """ + Game-specific Mod Controller for TimePilot. + It simply inherits all logic from JaxAtariModController and defines the REGISTRY. + """ + + REGISTRY = { + "dont_kill": DontKillMod, + "matrix_theme": MatrixMod, + "extra_lives": ExtraLivesMod, + "instant_turn": InstantTurnMod, + "reverse_chronology": ReverseChronologyMod, + } + + _mod_sprite_dir = os.path.join(os.path.dirname(__file__), "timepilot", "sprites") + + def __init__(self, + env, + mods_config: list = [], + allow_conflicts: bool = False + ): + + super().__init__( + env=env, + mods_config=mods_config, + allow_conflicts=allow_conflicts, + registry=self.REGISTRY + ) From fcb159ed6deab479f618607f3c28992bf77b5312 Mon Sep 17 00:00:00 2001 From: Quentin Delfosse Date: Tue, 5 May 2026 22:25:59 +0200 Subject: [PATCH 20/28] Kangaroo and Seaquest reward mods --- scripts/play.py | 3 +- src/jaxatari/games/jax_seaquest.py | 19 ++++++-- .../mods/kangaroo/kangaroo_mod_plugins.py | 48 +++++++++++++++++++ src/jaxatari/games/mods/kangaroo_mods.py | 3 +- .../mods/seaquest/seaquest_mod_plugins.py | 25 ++++++++++ src/jaxatari/games/mods/seaquest_mods.py | 3 +- 6 files changed, 94 insertions(+), 7 deletions(-) diff --git a/scripts/play.py b/scripts/play.py index 7b6c3824f..63d131038 100644 --- a/scripts/play.py +++ b/scripts/play.py @@ -426,7 +426,8 @@ def running_fn(): if not frame_by_frame or next_frame_asked: obs, state, reward, done, info = jitted_step(state, action) - # print(reward) + if reward != 0: + print(reward) total_return += reward if next_frame_asked: next_frame_asked = False diff --git a/src/jaxatari/games/jax_seaquest.py b/src/jaxatari/games/jax_seaquest.py index 2a52ca490..1b4f26d51 100644 --- a/src/jaxatari/games/jax_seaquest.py +++ b/src/jaxatari/games/jax_seaquest.py @@ -2816,12 +2816,23 @@ def render(self, state: SeaquestState) -> jnp.ndarray: # --- UI Elements (Unchanged) --- max_score_digits = 6 - score_digits = self.jr.int_to_digits(state.score, max_digits=max_score_digits) - clamped_score = jnp.minimum(jnp.maximum(state.score, 0), 10**max_score_digits - 1) - score_digit_thresholds = jnp.array([1, 10, 100, 1000, 10000, 100000], dtype=clamped_score.dtype) - num_score_digits = jnp.maximum(1, jnp.sum(clamped_score >= score_digit_thresholds)) + abs_score = jnp.abs(state.score) + score_digits = self.jr.int_to_digits(abs_score, max_digits=max_score_digits) + clamped_abs_score = jnp.minimum(abs_score, 10**max_score_digits - 1) + score_digit_thresholds = jnp.array([1, 10, 100, 1000, 10000, 100000], dtype=clamped_abs_score.dtype) + num_score_digits = jnp.maximum(1, jnp.sum(clamped_abs_score >= score_digit_thresholds)) score_start_index = max_score_digits - num_score_digits score_x = 59 + score_start_index * 8 + + # Render negative sign if needed + is_negative = state.score < 0 + raster = jax.lax.cond( + is_negative, + lambda r: self.jr.render_at(r, score_x - 8, 9, self.SHAPE_MASKS['digits'][10]), + lambda r: r, + raster + ) + raster = self.jr.render_label_selective( raster, score_x, diff --git a/src/jaxatari/games/mods/kangaroo/kangaroo_mod_plugins.py b/src/jaxatari/games/mods/kangaroo/kangaroo_mod_plugins.py index 1e2a94de6..2fc342de4 100644 --- a/src/jaxatari/games/mods/kangaroo/kangaroo_mod_plugins.py +++ b/src/jaxatari/games/mods/kangaroo/kangaroo_mod_plugins.py @@ -89,6 +89,54 @@ def _monkey_controller(self, state: KangarooState, punching: chex.Array): jnp.array(False), ) +class DontPunchMod(JaxAtariInternalModPlugin): + """ + Internal mod that provides negative reward for punching monkeys. + """ + @partial(jax.jit, static_argnums=(0,)) + def _get_reward(self, previous_state: KangarooState, state: KangarooState) -> float: + # Standard reward + reward = state.score - previous_state.score + + punching = state.player.punch_left | state.player.punch_right + + # Fist position (re-calculate as in _monkey_controller) + fist_w = 3 + fist_h = 4 + fist_x = jnp.where( + state.player.orientation > 0, + state.player.x + self._env.consts.PLAYER_WIDTH, + state.player.x - fist_w, + ) + fist_y = state.player.y + 8 + + def check_punch(f_x, f_y, f_w, f_h, m_x, m_y, m_w, m_h, m_state, punching): + return jnp.logical_and( + self._env._entities_collide(f_x, f_y, f_w, f_h, m_x, m_y, m_w, m_h), + jnp.logical_and(m_state != 0, punching), + ) + + monkeys_punched = jax.vmap( + check_punch, + in_axes=(None, None, None, None, 0, 0, None, None, 0, None), + )( + fist_x, + fist_y, + fist_w, + fist_h, + previous_state.level.monkey_positions[:, 0], + previous_state.level.monkey_positions[:, 1], + self._env.consts.MONKEY_WIDTH, + self._env.consts.MONKEY_HEIGHT, + previous_state.level.monkey_states, + punching, + ) + + num_punched = jnp.sum(monkeys_punched) + # The game already gives +200 per monkey. + # To make it net negative (e.g. -200), we subtract 400. + return reward - num_punched * 400.0 + class NoFallingCoconutMod(JaxAtariInternalModPlugin): """ Internal mod to disable the single falling coconut. diff --git a/src/jaxatari/games/mods/kangaroo_mods.py b/src/jaxatari/games/mods/kangaroo_mods.py index 65144a790..7506b2f04 100644 --- a/src/jaxatari/games/mods/kangaroo_mods.py +++ b/src/jaxatari/games/mods/kangaroo_mods.py @@ -8,7 +8,7 @@ FirstLevelOnlyMod, SecondLevelOnlyMod, ThirdLevelOnlyMod, FourLaddersMod, ReplaceCoconutWithFireball, ReplaceCoconutWithHoneyBee, ReplaceCoconutWithWasp, ReplaceMonkeyWithChickenMod, ReplaceMonkeyWithDragonMod, ReplaceMonkeyWithDangerSignMod, ReplaceMonkeyWithPolarbearMod, ReplaceMonkeyWithSnakeMod, ReplaceBellWithDangerSignMod, - ReplaceFruitWithCoin, ReplaceFruitWithDiamond + ReplaceFruitWithCoin, ReplaceFruitWithDiamond, DontPunchMod ) # --- 3. The Registry --- @@ -16,6 +16,7 @@ "no_bell": NoBellMod, "no_fruit": NoFruitMod, "no_monkey": NoMonkeyMod, + "dont_punch": DontPunchMod, "no_falling_coconut": NoFallingCoconutMod, "no_thrown_coconut": NoThrownCoconutMod, "high_thrown_coconuts": AlwaysHighCoconutMod, diff --git a/src/jaxatari/games/mods/seaquest/seaquest_mod_plugins.py b/src/jaxatari/games/mods/seaquest/seaquest_mod_plugins.py index 966c30301..4c7d3ea3e 100644 --- a/src/jaxatari/games/mods/seaquest/seaquest_mod_plugins.py +++ b/src/jaxatari/games/mods/seaquest/seaquest_mod_plugins.py @@ -120,4 +120,29 @@ def run(self, prev_state: SeaquestState, new_state: SeaquestState) -> SeaquestSt class RandomColorEnemiesMod(JaxAtariInternalModPlugin): pass +class DontKillMod(JaxAtariInternalModPlugin): + """ + Internal mod that punishes killing and shooting. + """ + @partial(jax.jit, static_argnums=(0,)) + def calculate_kill_points(self, successful_rescues: chex.Array) -> chex.Array: + # Punish killing by returning a negative value + return jnp.array(-100, dtype=jnp.int32) + + @partial(jax.jit, static_argnums=(0,)) + def _get_reward(self, previous_state: SeaquestState, state: SeaquestState): + # Standard reward (which now includes the negative kill points from calculate_kill_points) + reward = state.score - previous_state.score + + # Punish shooting + # A shot is fired when a missile is newly active (state[2] != 0) + shot_fired = jnp.logical_and( + previous_state.player_missile_position[2] == 0, + state.player_missile_position[2] != 0 + ) + # We use a penalty for every shot fired + shooting_penalty = jnp.where(shot_fired, 10.0, 0.0) + + return reward.astype(jnp.float32) - shooting_penalty + diff --git a/src/jaxatari/games/mods/seaquest_mods.py b/src/jaxatari/games/mods/seaquest_mods.py index 3b777be1d..bc8371907 100644 --- a/src/jaxatari/games/mods/seaquest_mods.py +++ b/src/jaxatari/games/mods/seaquest_mods.py @@ -3,7 +3,7 @@ import jax.numpy as jnp from functools import partial from jaxatari.modification import JaxAtariModController -from jaxatari.games.mods.seaquest.seaquest_mod_plugins import DisableEnemiesMod, NoDiversMod, EnemyMinesMod, FireBallsMod, UnlimitedOxygenMod, GravityMod, RandomColorEnemiesMod +from jaxatari.games.mods.seaquest.seaquest_mod_plugins import DisableEnemiesMod, NoDiversMod, EnemyMinesMod, FireBallsMod, UnlimitedOxygenMod, GravityMod, RandomColorEnemiesMod, DontKillMod class SeaquestEnvMod(JaxAtariModController): """ @@ -22,6 +22,7 @@ class SeaquestEnvMod(JaxAtariModController): "random_color_enemies": RandomColorEnemiesMod, # "polluted_water": PollutedWaterMod, "mines": EnemyMinesMod, + "dont_kill": DontKillMod, # "fireball": ReplaceTorpedoWithFireBallMod } From 62e3a79ffccc6598291828203d2d17ae5f0f3179 Mon Sep 17 00:00:00 2001 From: Quentin Delfosse Date: Tue, 5 May 2026 22:59:38 +0200 Subject: [PATCH 21/28] Fixing Kangaroo negative score --- src/jaxatari/games/jax_kangaroo.py | 20 ++++++++++++- .../mods/kangaroo/kangaroo_mod_plugins.py | 29 ++++++++++--------- 2 files changed, 35 insertions(+), 14 deletions(-) diff --git a/src/jaxatari/games/jax_kangaroo.py b/src/jaxatari/games/jax_kangaroo.py index f555f72ee..99a6fca57 100644 --- a/src/jaxatari/games/jax_kangaroo.py +++ b/src/jaxatari/games/jax_kangaroo.py @@ -2354,9 +2354,27 @@ def _draw_coco(i, current_raster): # --- 4. Draw UI --- # Score - score_digits = self.jr.int_to_digits(state.score, max_digits=6) + is_negative = state.score < 0 + score_digits = self.jr.int_to_digits(jnp.abs(state.score), max_digits=6) raster = self.jr.render_label(raster, 105, 182, score_digits, self.SHAPE_MASKS["score_digits"], spacing=8, max_digits=6) + # Draw minus sign if score is negative, in front of the 6 digits (left side) + def draw_minus(r): + transparent_id = self.jr.TRANSPARENT_ID + mask_0 = self.SHAPE_MASKS["score_digits"][0] + # The digit color is the one that is not transparent + digit_color = jnp.where(mask_0 != transparent_id, mask_0, jnp.inf).min().astype(mask_0.dtype) + minus_mask = jnp.full((7, 7), transparent_id, dtype=mask_0.dtype) + minus_mask = minus_mask.at[3, 2:5].set(digit_color) + return self.jr.render_at(r, 105 - 8, 182, minus_mask) + + raster = jax.lax.cond( + is_negative, + draw_minus, + lambda r: r, + raster + ) + # Lives lives_count = jnp.maximum(state.lives.astype(int) - 1, 0) raster = self.jr.render_indicator(raster, 15, 182, lives_count, self.SHAPE_MASKS["lives"], spacing=8, max_value=5) diff --git a/src/jaxatari/games/mods/kangaroo/kangaroo_mod_plugins.py b/src/jaxatari/games/mods/kangaroo/kangaroo_mod_plugins.py index 2fc342de4..010ea0bf8 100644 --- a/src/jaxatari/games/mods/kangaroo/kangaroo_mod_plugins.py +++ b/src/jaxatari/games/mods/kangaroo/kangaroo_mod_plugins.py @@ -89,26 +89,24 @@ def _monkey_controller(self, state: KangarooState, punching: chex.Array): jnp.array(False), ) -class DontPunchMod(JaxAtariInternalModPlugin): +class DontPunchMod(JaxAtariPostStepModPlugin): """ - Internal mod that provides negative reward for punching monkeys. + Post-step mod that provides negative reward for punching monkeys. + It recalculates if a monkey was punched and updates the game score. """ @partial(jax.jit, static_argnums=(0,)) - def _get_reward(self, previous_state: KangarooState, state: KangarooState) -> float: - # Standard reward - reward = state.score - previous_state.score - + def run(self, previous_state: KangarooState, state: KangarooState) -> KangarooState: punching = state.player.punch_left | state.player.punch_right # Fist position (re-calculate as in _monkey_controller) fist_w = 3 fist_h = 4 fist_x = jnp.where( - state.player.orientation > 0, - state.player.x + self._env.consts.PLAYER_WIDTH, - state.player.x - fist_w, + previous_state.player.orientation > 0, + previous_state.player.x + self._env.consts.PLAYER_WIDTH, + previous_state.player.x - fist_w, ) - fist_y = state.player.y + 8 + fist_y = previous_state.player.y + 8 def check_punch(f_x, f_y, f_w, f_h, m_x, m_y, m_w, m_h, m_state, punching): return jnp.logical_and( @@ -133,9 +131,14 @@ def check_punch(f_x, f_y, f_w, f_h, m_x, m_y, m_w, m_h, m_state, punching): ) num_punched = jnp.sum(monkeys_punched) - # The game already gives +200 per monkey. - # To make it net negative (e.g. -200), we subtract 400. - return reward - num_punched * 400.0 + + # The base game already added +200 per monkey to `state.score`. + # To make it -200 per monkey, we subtract 400. + penalty = num_punched * 400 + + # Update the score + new_score = state.score - penalty + return state.replace(score=new_score) class NoFallingCoconutMod(JaxAtariInternalModPlugin): """ From ce0287dfa5be1a22d7b55ca96695687cf150602b Mon Sep 17 00:00:00 2001 From: Quentin Delfosse Date: Wed, 6 May 2026 00:20:19 +0200 Subject: [PATCH 22/28] Various mods implemented --- src/jaxatari/games/jax_beamrider.py | 44 +++- src/jaxatari/games/jax_frostbite.py | 223 +++++++++--------- .../mods/beamrider/beamrider_mod_plugins.py | 32 +++ src/jaxatari/games/mods/beamrider_mods.py | 2 + .../games/mods/freeway/freeway_mod_plugins.py | 26 ++ src/jaxatari/games/mods/freeway_mods.py | 3 +- .../mods/frostbite/frostbite_mod_plugins.py | 8 + src/jaxatari/games/mods/frostbite_mods.py | 3 +- 8 files changed, 222 insertions(+), 119 deletions(-) diff --git a/src/jaxatari/games/jax_beamrider.py b/src/jaxatari/games/jax_beamrider.py index 6426be75e..9699bca0f 100644 --- a/src/jaxatari/games/jax_beamrider.py +++ b/src/jaxatari/games/jax_beamrider.py @@ -365,6 +365,19 @@ class BeamriderConstants(struct.PyTreeNode): COIN_ANIM_SEQ: Tuple[int, ...] = struct.field(pytree_node=False, default=(3, 2, 1, 0, 1, 2)) COIN_SPRITE_SIZE: Tuple[int, int] = struct.field(pytree_node=False, default=(7, 8)) + UFO_REWARD: int = struct.field(pytree_node=False, default=40) + UFO_SECTOR_REWARD: int = struct.field(pytree_node=False, default=4) + BOUNCER_REWARD: int = struct.field(pytree_node=False, default=80) + REJUVENATOR_REWARD: int = struct.field(pytree_node=False, default=150) + COIN_REWARD: int = struct.field(pytree_node=False, default=300) + COIN_SECTOR_REWARD: int = struct.field(pytree_node=False, default=30) + COIN_LIFE_REWARD: int = struct.field(pytree_node=False, default=100) + COIN_LIFE_SECTOR_REWARD: int = struct.field(pytree_node=False, default=10) + MOTHERSHIP_REWARD: int = struct.field(pytree_node=False, default=300) + MOTHERSHIP_SECTOR_REWARD: int = struct.field(pytree_node=False, default=30) + MOTHERSHIP_LIFE_REWARD: int = struct.field(pytree_node=False, default=100) + MOTHERSHIP_LIFE_SECTOR_REWARD: int = struct.field(pytree_node=False, default=10) + def _get_index_ufo(pos: chex.Array) -> chex.Array: return _UFO_INDEX_TABLE[jnp.clip(pos.astype(jnp.int32), 0, 800)] @@ -1440,7 +1453,7 @@ def _collisions_step( bouncer_pos = jnp.where(bouncer_destroyed, self.enemy_offscreen, bouncer_pos) bouncer_active = jnp.where(bouncer_destroyed, False, bouncer_active) player_shot_pos = jnp.where(bouncer_hit, self.bullet_offscreen, player_shot_pos) - score = jnp.where(bouncer_destroyed, score + 80, score) + score = jnp.where(bouncer_destroyed, score + self.consts.BOUNCER_REWARD, score) # Meteoroid bullet collision pre_collision_meteoroid_pos = chasing_meteoroid_pos @@ -1492,7 +1505,7 @@ def _collisions_step( coin_active = jnp.where(hit_mask_coin, False, coin_active) player_shot_pos = jnp.where(hit_exists_coin, self.bullet_offscreen, player_shot_pos) clamped_sector = jnp.minimum(state.sector, 89) - score = jnp.where(hit_exists_coin, score + 300 + 30 * clamped_sector + jnp.maximum(state.lives - 1, 0) * (100 + 10 * clamped_sector), score) + score = jnp.where(hit_exists_coin, score + self.consts.COIN_REWARD + self.consts.COIN_SECTOR_REWARD * clamped_sector + jnp.maximum(state.lives - 1, 0) * (self.consts.COIN_LIFE_REWARD + self.consts.COIN_LIFE_SECTOR_REWARD * clamped_sector), score) # Rejuvenator bullet collision rejuv_hit, rejuv_destroyed = self._rejuvenator_bullet_collision(rejuv_pos, rejuv_active, rejuv_dead, player_shot_pos, @@ -1502,13 +1515,13 @@ def _collisions_step( rejuv_active = jnp.where(rejuv_destroyed, False, rejuv_active) rejuv_pos = jnp.where(rejuv_destroyed, self.enemy_offscreen, rejuv_pos) player_shot_pos = jnp.where(rejuv_hit, self.bullet_offscreen, player_shot_pos) - score = jnp.where(rejuv_destroyed, score + 150, score) + score = jnp.where(rejuv_destroyed, score + self.consts.REJUVENATOR_REWARD, score) # Mothership bullet collision hit_mothership = self._mothership_bullet_collision(state.level.mothership_stage, state.level.mothership_position, player_shot_pos, shot_x_screen, shot_active, bullet_size, is_torpedo) player_shot_pos = jnp.where(hit_mothership, self.bullet_offscreen, player_shot_pos) - score = jnp.where(hit_mothership, score + 300 + 30 * clamped_sector + jnp.maximum(state.lives - 1, 0) * (100 + 10 * clamped_sector), score) + score = jnp.where(hit_mothership, score + self.consts.MOTHERSHIP_REWARD + self.consts.MOTHERSHIP_SECTOR_REWARD * clamped_sector + jnp.maximum(state.lives - 1, 0) * (self.consts.MOTHERSHIP_LIFE_REWARD + self.consts.MOTHERSHIP_LIFE_SECTOR_REWARD * clamped_sector), score) # Enemy shot collision hit_mask_shot, hit_exists_shot = self._enemy_shot_bullet_collision(enemy_shot_pos, enemy_shot_timer, player_shot_pos, @@ -2146,7 +2159,7 @@ def _collision_handler( state.level.white_ufo_left, ) clamped_sector = jnp.minimum(state.sector, 89) - ufo_score = 40 + 4 * clamped_sector + ufo_score = self.consts.UFO_REWARD + self.consts.UFO_SECTOR_REWARD * clamped_sector score = jnp.where(hit_exists_ufo, state.score + ufo_score, state.score) return (enemy_pos, player_shot_pos, new_patterns, new_timers, new_spawn_delays, white_ufo_left, score, hit_mask_ufo, hit_exists_ufo) @@ -4600,11 +4613,28 @@ def _render_white_ufo_counter(self, raster, state): ) def _render_score(self, raster, state): - yellow_digits = self.jr.int_to_digits(state.score, max_digits=6) + abs_score = jnp.abs(state.score) + yellow_digits = self.jr.int_to_digits(abs_score, max_digits=6) score_masks = self.SHAPE_MASKS["yellow_numbers"] - return self.jr.render_label_selective( + raster = self.jr.render_label_selective( raster, 61, 10, yellow_digits, score_masks, 0, 6, spacing=8, max_digits_to_render=6 ) + + # Add negative sign if score < 0 + is_negative = state.score < 0 + digit_mask = score_masks[0] + minus_color_id = jnp.max(digit_mask) + minus_mask = jnp.zeros_like(digit_mask) + mid_y = digit_mask.shape[0] // 2 + minus_mask = minus_mask.at[mid_y, 1:-1].set(minus_color_id) + + raster = jax.lax.cond( + is_negative, + lambda r: self.jr.render_at(r, 61 - 8, 10, minus_mask), + lambda r: r, + raster + ) + return raster def _render_sector(self, raster, state): sector_digits = self.jr.int_to_digits(state.sector, max_digits=2) diff --git a/src/jaxatari/games/jax_frostbite.py b/src/jaxatari/games/jax_frostbite.py index 8566139b4..c7e6b0c67 100644 --- a/src/jaxatari/games/jax_frostbite.py +++ b/src/jaxatari/games/jax_frostbite.py @@ -126,6 +126,12 @@ class FrostbiteConstants(struct.PyTreeNode): MAX_EATEN_FISH: int = struct.field(pytree_node=False, default=12) # Max fish that can be eaten per level MAX_RESERVED_LIVES: int = struct.field(pytree_node=False, default=9) # Maximum reserve lives + # Rewards + REWARD_ICE_BLOCK: int = struct.field(pytree_node=False, default=10) # Multiplied by level + REWARD_FISH: int = struct.field(pytree_node=False, default=200) + REWARD_IGLOO_BLOCK: int = struct.field(pytree_node=False, default=10) # Multiplied by level + REWARD_TEMPERATURE: int = struct.field(pytree_node=False, default=10) # Multiplied by level + # Status Masks OBSTACLE_DIR_MASK: int = struct.field(pytree_node=False, default=0x80) ICE_BLOCK_DIR_MASK: int = struct.field(pytree_node=False, default=0x40) @@ -421,6 +427,9 @@ def lives(self): ice_segments_x: chex.Array ice_segments_w: chex.Array + # Score sign (1 for positive, -1 for negative) + score_sign: chex.Array + # JAX rng_key: chex.PRNGKey @@ -490,14 +499,15 @@ def __init__(self, consts: FrostbiteConstants = None): self.renderer = FrostbiteRenderer(self.consts) - def _get_point_value_for_level(self, level: jnp.ndarray): - """Get point value for level × 10 in BCD format""" - points = level * 10 - points = jnp.minimum(points, 90) - hundreds = points // 100 - tens = (points % 100) // 10 - ones = points % 10 - return (hundreds << 8) | (tens << 4) | ones + def _get_reward_decimal(self, level: jnp.ndarray, base_reward: int): + """Get point value for level × base_reward""" + points = level * base_reward + # Original game caps the level multiplier at 9 + return jnp.minimum(points, 9 * base_reward) + + def _get_fish_reward_decimal(self): + """Get fish reward""" + return jnp.int32(self.consts.REWARD_FISH) def _get_obstacle_pattern_mask(self, level: jnp.ndarray): """Get pattern mask to control obstacle density based on level. @@ -724,6 +734,7 @@ def reset(self, key: jax.random.PRNGKey = None) -> Tuple[FrostbiteObservation, F reserve_lives=jnp.array(self.consts.INIT_LIVES, dtype=jnp.int32), ice_segments_x=jnp.full((4, 6), self.consts.ICE_UNUSED_POS, dtype=jnp.int32), ice_segments_w=jnp.zeros((4, 6), dtype=jnp.int32), + score_sign=jnp.array(1, dtype=jnp.int32), rng_key=key ) # Spawn all 4 obstacles at the start @@ -820,6 +831,7 @@ def reset(self, key: jax.random.PRNGKey = None) -> Tuple[FrostbiteObservation, F reserve_lives=jnp.array(self.consts.INIT_LIVES, dtype=jnp.int32), ice_segments_x=jnp.full((4, 6), self.consts.ICE_UNUSED_POS, dtype=jnp.int32), ice_segments_w=jnp.zeros((4, 6), dtype=jnp.int32), + score_sign=jnp.array(1, dtype=jnp.int32), rng_key=key ) seg_x, seg_w = self._compute_ice_segments(state) @@ -912,7 +924,7 @@ def _get_observation(self, state: FrostbiteState) -> FrostbiteObservation: hits = active & (px >= seg_x) & (px < seg_x + seg_w) # Shape: (4, 6, 16) ice_grid = jnp.any(hits, axis=1).astype(jnp.int32) # Shape: (4, 16) - score_val = self._bcd_to_decimal(state.score) + score_val = self._bcd_to_decimal(state.score) * state.score_sign temp_val = self._bcd_to_decimal(jnp.array([0, 0, state.temperature], dtype=jnp.int32)) return FrostbiteObservation( @@ -1124,8 +1136,8 @@ def sel(old, new): return jnp.where(spawn_mask, new, old) @partial(jax.jit, static_argnums=(0,)) def _get_reward(self, previous_state: FrostbiteState, state: FrostbiteState) -> chex.Array: """Calculate reward based on score difference.""" - prev_score_val = self._bcd_to_decimal(previous_state.score) - curr_score_val = self._bcd_to_decimal(state.score) + prev_score_val = self._bcd_to_decimal(previous_state.score) * previous_state.score_sign + curr_score_val = self._bcd_to_decimal(state.score) * state.score_sign return (curr_score_val - prev_score_val).astype(jnp.float32) @partial(jax.jit, static_argnums=(0,)) @@ -1285,11 +1297,11 @@ def _process_level_complete(self, state: FrostbiteState): should_remove_block, state.building_igloo_idx - 1, state.building_igloo_idx ) - # Award points for each removed block (level × 10 in BCD) - point_value = self._get_point_value_for_level(state.level) - block_points = jnp.where(should_remove_block, point_value, 0) - new_score = self._add_bcd_score(state.score, block_points) - new_remaining_lives = self._check_extra_life(state.score, new_score, state.remaining_lives) + # Award points for each removed block (level × base in BCD) + block_reward = self._get_reward_decimal(state.level, self.consts.REWARD_IGLOO_BLOCK) + block_points = jnp.where(should_remove_block, block_reward, 0) + new_score, new_score_sign = self._add_score_decimal(state.score, state.score_sign, block_points) + new_remaining_lives = self._check_extra_life(state.score, state.score_sign, new_score, new_score_sign, state.remaining_lives) # When all blocks are removed, start temperature countdown phase blocks_just_finished = is_level_complete & (state.building_igloo_idx == 0) & (new_building_idx < 0) @@ -1316,10 +1328,12 @@ def decrement_temp_bcd(temp): new_temperature = jnp.where(should_decrement_temp, decrement_temp_bcd(state.temperature), state.temperature) # Award points for each temperature degree (same as block points) - temp_points = jnp.where(should_decrement_temp, point_value, 0) + temp_reward = self._get_reward_decimal(state.level, self.consts.REWARD_TEMPERATURE) + temp_points = jnp.where(should_decrement_temp, temp_reward, 0) old_score_for_temp = new_score # Use the already-updated score from block removal - new_score = self._add_bcd_score(new_score, temp_points) - new_remaining_lives = self._check_extra_life(old_score_for_temp, new_score, new_remaining_lives) + old_sign_for_temp = new_score_sign + new_score, new_score_sign = self._add_score_decimal(new_score, new_score_sign, temp_points) + new_remaining_lives = self._check_extra_life(old_score_for_temp, old_sign_for_temp, new_score, new_score_sign, new_remaining_lives) # Phase 3: Reset game state for next level when temperature reaches zero level_reset_complete = ( @@ -1452,6 +1466,7 @@ def decrement_temp_bcd(temp): frame_delay=new_delay, building_igloo_idx=new_building_idx, score=new_score, + score_sign=new_score_sign, remaining_lives=new_remaining_lives, bailey_x=new_bailey_x, bailey_y=new_bailey_y, @@ -2341,16 +2356,15 @@ def check_single_wrap(offset): ) # Award points and igloo blocks for collecting ice - new_score = state.score # No need for .copy() in JAX new_building_idx = state.building_igloo_idx - # Points based on level (level × 10 in BCD) - point_value = self._get_point_value_for_level(state.level) - points = jnp.where(row_changed, point_value, 0) - new_score = self._add_bcd_score(new_score, points) - + # Points based on level + ice_reward = self._get_reward_decimal(state.level, self.consts.REWARD_ICE_BLOCK) + points = jnp.where(row_changed, ice_reward, 0) + new_score, new_score_sign = self._add_score_decimal(state.score, state.score_sign, points) + # Check for extra life - new_remaining_lives = self._check_extra_life(state.score, new_score, new_remaining_lives) + new_remaining_lives = self._check_extra_life(state.score, state.score_sign, new_score, new_score_sign, new_remaining_lives) new_building_idx = jnp.where( row_changed & (state.building_igloo_idx < self.consts.MAX_IGLOO_INDEX), @@ -2407,6 +2421,7 @@ def check_single_wrap(offset): completed_ice_blocks_delay=new_delay, bailey_death_frame=final_death_frame, score=new_score, + score_sign=new_score_sign, building_igloo_idx=new_building_idx, remaining_lives=new_remaining_lives, temperature=new_temperature, @@ -2433,67 +2448,31 @@ def _spawn_all(s): return updated_state - def _add_bcd_score(self, score, points): - """Add points to BCD score (Binary Coded Decimal).""" - current = self._bcd_to_decimal(score) - hundreds = (points >> 8) & 0x0F - tens = (points >> 4) & 0x0F - ones = points & 0x0F - decimal_points = hundreds * 100 + tens * 10 + ones + def _add_score_decimal(self, score, score_sign, decimal_points): + """Add decimal points to BCD score, returning new score and sign.""" + current = self._bcd_to_decimal(score) * score_sign new_total = current + decimal_points - d5 = (new_total // 100000) % 10 - d4 = (new_total // 10000) % 10 - d3 = (new_total // 1000) % 10 - d2 = (new_total // 100) % 10 - d1 = (new_total // 10) % 10 - d0 = new_total % 10 + new_sign = jnp.where(new_total < 0, jnp.int32(-1), jnp.int32(1)) + abs_total = jnp.abs(new_total) + d5 = (abs_total // 100000) % 10 + d4 = (abs_total // 10000) % 10 + d3 = (abs_total // 1000) % 10 + d2 = (abs_total // 100) % 10 + d1 = (abs_total // 10) % 10 + d0 = abs_total % 10 new_score = score.at[0].set((d5 << 4) | d4) new_score = new_score.at[1].set((d3 << 4) | d2) new_score = new_score.at[2].set((d1 << 4) | d0) - return new_score - - def _check_extra_life(self, old_score, new_score, lives): - """Check if score crossed a 5000-point boundary and award extra life. - - In Frostbite, players earn an extra life every 5000 points. This method - checks if the score increase crossed one or more 5000-point thresholds - and awards lives accordingly. Maximum of 9 reserve lives is enforced. - - Args: - old_score: Previous integer score - new_score: Updated integer score after points were added - lives: Current number of reserve lives - - Returns: - Updated number of lives, capped at 9 (hardcoded limit) - """ - # Calculate how many 5000-point thresholds were crossed - # Integer division gives us the number of extra lives earned so far - old_total = self._bcd_to_decimal(old_score) - new_total = self._bcd_to_decimal(new_score) - old_lives_earned = old_total // 5000 - new_lives_earned = new_total // 5000 - - # Award extra lives for all thresholds crossed, but enforce the 9-life maximum - # Note: MAX_RESERVED_LIVES constant exists but isn't used - limit is hardcoded + return new_score, new_sign + + def _check_extra_life(self, old_score, old_sign, new_score, new_sign, lives): + """Check if score crossed a 5000-point boundary and award extra life.""" + old_total = self._bcd_to_decimal(old_score) * old_sign + new_total = self._bcd_to_decimal(new_score) * new_sign + old_lives_earned = jnp.maximum(0, old_total) // 5000 + new_lives_earned = jnp.maximum(0, new_total) // 5000 lives_to_add = new_lives_earned - old_lives_earned - new_lives = jnp.where(lives_to_add > 0, jnp.minimum(lives + lives_to_add, 9), lives) - - return new_lives - - def _add_bcd_score_decimal(self, score, decimal_points): - current = self._bcd_to_decimal(score) - new_total = current + decimal_points - d5 = (new_total // 100000) % 10 - d4 = (new_total // 10000) % 10 - d3 = (new_total // 1000) % 10 - d2 = (new_total // 100) % 10 - d1 = (new_total // 10) % 10 - d0 = new_total % 10 - new_score = score.at[0].set((d5 << 4) | d4) - new_score = new_score.at[1].set((d3 << 4) | d2) - new_score = new_score.at[2].set((d1 << 4) | d0) - return new_score + return jnp.where(lives_to_add > 0, jnp.minimum(lives + lives_to_add, 9), lives) @partial(jax.jit, static_argnums=(0,)) def _check_obstacle_collisions(self, state: FrostbiteState): @@ -2570,15 +2549,20 @@ def alive_bit(mask, k): return ((mask >> jnp.int32(k)) & 1) == 1 hit_is_fish = has_hit & (hit_type == self.consts.ID_FISH) can_award = hit_is_fish & (state.number_of_fish_eaten < self.consts.MAX_EATEN_FISH) - new_score = jax.lax.cond( + def award_fish(score_info): + score, sign = score_info + return self._add_score_decimal(score, sign, self._get_fish_reward_decimal()) + + new_score, new_score_sign = jax.lax.cond( can_award, - lambda s: self._add_bcd_score(s, jnp.int32(0x0200)), - lambda s: s, - state.score + award_fish, + lambda args: args, + (state.score, state.score_sign) ) + new_remaining_lives = jax.lax.cond( can_award, - lambda _: self._check_extra_life(state.score, new_score, state.remaining_lives), + lambda _: self._check_extra_life(state.score, state.score_sign, new_score, new_score_sign, state.remaining_lives), lambda _: state.remaining_lives, operand=None ) @@ -2630,6 +2614,7 @@ def clear_bit(mask, k): fish_alive_mask=new_fish_masks, obstacle_x=new_obstacle_x, score=new_score, + score_sign=new_score_sign, remaining_lives=new_remaining_lives ) @@ -3442,63 +3427,81 @@ def _render_hud(self, raster, state): temp_x = jnp.where(temp_tens > 0, temp_x_base - 16, temp_x_base - 8) # 3. Score - total_score = self._bcd_to_decimal(state.score) + total_score_val = self._bcd_to_decimal(state.score) * state.score_sign + total_score_abs = jnp.abs(total_score_val) + score_is_negative = total_score_val < 0 + num_digits = jnp.where( - total_score == 0, + total_score_abs == 0, 1, - jnp.floor(jnp.log10(jnp.maximum(1, total_score) + 0.5)).astype(jnp.int32) + 1 + jnp.floor(jnp.log10(jnp.maximum(1, total_score_abs) + 0.5)).astype(jnp.int32) + 1 ) score_x_start = lives_x - ((num_digits - 1) * 4) - score_digits = self.jr.int_to_digits(total_score, max_digits=6) + score_digits = self.jr.int_to_digits(total_score_abs, max_digits=6) start_index = 6 - num_digits - + # Collect everything into a batch - # 0: Lives, 1: Temp Tens, 2: Temp Ones, 3: Degree, 4-9: Score Digits - hud_masks = jnp.zeros((10, *digits_masks.shape[1:]), dtype=digits_masks.dtype) - hud_x = jnp.zeros(10, dtype=jnp.int32) - hud_y = jnp.zeros(10, dtype=jnp.int32) - hud_active = jnp.zeros(10, dtype=jnp.bool_) - + # 0: Lives, 1: Temp Tens, 2: Temp Ones, 3: Degree, 4-9: Score Digits, 10: Minus Sign + hud_masks = jnp.zeros((11, *digits_masks.shape[1:]), dtype=digits_masks.dtype) + hud_x = jnp.zeros(11, dtype=jnp.int32) + hud_y = jnp.zeros(11, dtype=jnp.int32) + hud_active = jnp.zeros(11, dtype=jnp.bool_) + # Lives hud_masks = hud_masks.at[0].set(lives_mask) hud_x = hud_x.at[0].set(lives_x) hud_y = hud_y.at[0].set(lives_y) hud_active = hud_active.at[0].set(is_visible) - + # Temp Tens hud_masks = hud_masks.at[1].set(digits_masks[temp_tens]) hud_x = hud_x.at[1].set(temp_x) hud_y = hud_y.at[1].set(temp_y) hud_active = hud_active.at[1].set(is_visible & (temp_tens > 0)) - + # Temp Ones hud_masks = hud_masks.at[2].set(digits_masks[temp_ones]) hud_x = hud_x.at[2].set(jnp.where(temp_tens > 0, temp_x + 8, temp_x)) hud_y = hud_y.at[2].set(temp_y) hud_active = hud_active.at[2].set(is_visible) - + # Degree hud_masks = hud_masks.at[3].set(self.DEGREE_MASK) hud_x = hud_x.at[3].set(temp_x_base) hud_y = hud_y.at[3].set(temp_y) hud_active = hud_active.at[3].set(is_visible) - + # Score Digits score_idx = jnp.arange(6) score_mask_idx = jnp.take(score_digits, start_index + score_idx) score_masks = digits_masks[score_mask_idx] score_x = score_x_start + score_idx * 8 - - hud_masks = hud_masks.at[4:].set(score_masks) - hud_x = hud_x.at[4:].set(score_x) - hud_y = hud_y.at[4:].set(score_y) - hud_active = hud_active.at[4:].set(score_idx < num_digits) - + + hud_masks = hud_masks.at[4:10].set(score_masks) + hud_x = hud_x.at[4:10].set(score_x) + hud_y = hud_y.at[4:10].set(score_y) + hud_active = hud_active.at[4:10].set(score_idx < num_digits) + + # Minus Sign + minus_x = score_x_start - 8 + minus_y = score_y + text_color_idx = jnp.max(digits_masks[0]) + h, w = digits_masks.shape[1], digits_masks.shape[2] + yy, xx = jnp.mgrid[:h, :w] + minus_mask = jnp.where( + (yy >= h // 2 - 1) & (yy <= h // 2) & (xx > 1) & (xx < w - 1), + text_color_idx, + self.jr.TRANSPARENT_ID + ) + hud_masks = hud_masks.at[10].set(minus_mask) + hud_x = hud_x.at[10].set(minus_x) + hud_y = hud_y.at[10].set(minus_y) + hud_active = hud_active.at[10].set(score_is_negative) + # Final filtering hud_masks = jnp.where(hud_active[:, jnp.newaxis, jnp.newaxis], hud_masks, self.jr.TRANSPARENT_ID) - - return self.jr.render_at_batch(raster, hud_x, hud_y, hud_masks) + return self.jr.render_at_batch(raster, hud_x, hud_y, hud_masks) @partial(jax.jit, static_argnums=(0,)) def _render_igloo_blocks(self, raster, state): diff --git a/src/jaxatari/games/mods/beamrider/beamrider_mod_plugins.py b/src/jaxatari/games/mods/beamrider/beamrider_mod_plugins.py index 471228c7a..2026f1e80 100644 --- a/src/jaxatari/games/mods/beamrider/beamrider_mod_plugins.py +++ b/src/jaxatari/games/mods/beamrider/beamrider_mod_plugins.py @@ -6,6 +6,7 @@ from jaxatari.modification import JaxAtariInternalModPlugin from jaxatari.games.jax_beamrider import ( BLUE_LINE_INIT_TABLE, + BeamriderState, LaneBlockerState, WhiteUFOUpdate, WhiteUFOPattern, @@ -3621,3 +3622,34 @@ def stay_on_top(_): pattern_timer, new_key, ) + + +class DontKillMod(JaxAtariInternalModPlugin): + """Internal mod that punishes killing and shooting.""" + + constants_overrides = { + "UFO_REWARD": -100, + "UFO_SECTOR_REWARD": 0, + "BOUNCER_REWARD": -100, + "REJUVENATOR_REWARD": -100, + "MOTHERSHIP_REWARD": -100, + "MOTHERSHIP_SECTOR_REWARD": 0, + "MOTHERSHIP_LIFE_REWARD": 0, + "MOTHERSHIP_LIFE_SECTOR_REWARD": 0, + } + + @partial(jax.jit, static_argnums=(0,)) + def _get_reward(self, previous_state: BeamriderState, state: BeamriderState): + # Standard reward (which now includes the negative kill points from overrides) + reward = state.score - previous_state.score + + # Punish shooting + # A shot is fired when player_shot_frame transitions from -1 to >= 0 + shot_fired = jnp.logical_and( + previous_state.level.player_shot_frame == -1, + state.level.player_shot_frame != -1 + ) + # We use a penalty for every shot fired + shooting_penalty = jnp.where(shot_fired, 10.0, 0.0) + + return reward.astype(jnp.float32) - shooting_penalty diff --git a/src/jaxatari/games/mods/beamrider_mods.py b/src/jaxatari/games/mods/beamrider_mods.py index 6517270f5..6a8bf654b 100644 --- a/src/jaxatari/games/mods/beamrider_mods.py +++ b/src/jaxatari/games/mods/beamrider_mods.py @@ -1,5 +1,6 @@ from jaxatari.modification import JaxAtariModController from jaxatari.games.mods.beamrider.beamrider_mod_plugins import ( + DontKillMod, DoubleEnemySpeedMod, FogOfWarMod, HardcoreMod, @@ -17,6 +18,7 @@ class BeamRiderEnvMod(JaxAtariModController): """ REGISTRY = { + "dont_kill": DontKillMod, "double_enemy_speed": DoubleEnemySpeedMod, "fog_of_war": FogOfWarMod, "hardcore": HardcoreMod, diff --git a/src/jaxatari/games/mods/freeway/freeway_mod_plugins.py b/src/jaxatari/games/mods/freeway/freeway_mod_plugins.py index 10b91e09f..913129241 100644 --- a/src/jaxatari/games/mods/freeway/freeway_mod_plugins.py +++ b/src/jaxatari/games/mods/freeway/freeway_mod_plugins.py @@ -232,3 +232,29 @@ class GreenScoreMod(JaxAtariInternalModPlugin): 'data': _recolored_score } } + +class TooClosePenaltyMod(JaxAtariInternalModPlugin): + """Provides a -1 reward if the chicken is too close to a car on the same line.""" + @partial(jax.jit, static_argnums=(0,)) + def _get_reward(self, previous_state: FreewayState, state: FreewayState): + base_reward = state.score - previous_state.score + + chicken_x = self._env.consts.chicken_x + chicken_y_top = state.chicken_y - self._env.consts.chicken_height + chicken_y_bottom = state.chicken_y + + cars_x = state.cars[:, 0] + cars_y_top = state.cars[:, 1] - self._env.consts.car_height + cars_y_bottom = state.cars[:, 1] + + y_overlap = jnp.logical_and( + chicken_y_top < cars_y_bottom, + chicken_y_bottom > cars_y_top + ) + + x_close = jnp.abs(chicken_x - cars_x) < 10 + + too_close_any = jnp.any(jnp.logical_and(y_overlap, x_close)) + penalty = jnp.where(too_close_any, 1.0, 0.0) + + return base_reward.astype(jnp.float32) - penalty diff --git a/src/jaxatari/games/mods/freeway_mods.py b/src/jaxatari/games/mods/freeway_mods.py index 0fc50577e..3e5390cf7 100644 --- a/src/jaxatari/games/mods/freeway_mods.py +++ b/src/jaxatari/games/mods/freeway_mods.py @@ -1,6 +1,6 @@ import os from jaxatari.modification import JaxAtariModController -from jaxatari.games.mods.freeway.freeway_mod_plugins import StopAllCarsMod, StaticCarsMod, SlowCarsMod, BlackCarsMod, CenterCarsOnResetMod, InvertSpeed, HallOfFameMod, BikesMod, FrogMod, NewLaneColorsMod, GreenScoreMod +from jaxatari.games.mods.freeway.freeway_mod_plugins import StopAllCarsMod, StaticCarsMod, SlowCarsMod, BlackCarsMod, CenterCarsOnResetMod, InvertSpeed, HallOfFameMod, BikesMod, FrogMod, NewLaneColorsMod, GreenScoreMod, TooClosePenaltyMod class FreewayEnvMod(JaxAtariModController): """ @@ -21,6 +21,7 @@ class FreewayEnvMod(JaxAtariModController): "frog": FrogMod, "new_lane_colors": NewLaneColorsMod, "green_score": GreenScoreMod, + "too_close_penalty": TooClosePenaltyMod, "change_sprites": ["frog", "bikes", "new_lane_colors", "green_score"], } diff --git a/src/jaxatari/games/mods/frostbite/frostbite_mod_plugins.py b/src/jaxatari/games/mods/frostbite/frostbite_mod_plugins.py index 0269adc03..0098950ae 100644 --- a/src/jaxatari/games/mods/frostbite/frostbite_mod_plugins.py +++ b/src/jaxatari/games/mods/frostbite/frostbite_mod_plugins.py @@ -188,3 +188,11 @@ class DarkNightMod(JaxAtariInternalModPlugin): "RGB_NIGHT": (20, 20, 60), # Dark blue sky "DRAW_SHORE_LINE": True, } + +class BadFishesRewardMod(JaxAtariInternalModPlugin): + """ + Punishes the player for collecting fish by granting negative points. + """ + constants_overrides = { + "REWARD_FISH": -200, + } diff --git a/src/jaxatari/games/mods/frostbite_mods.py b/src/jaxatari/games/mods/frostbite_mods.py index fffe84910..9a6c4f284 100644 --- a/src/jaxatari/games/mods/frostbite_mods.py +++ b/src/jaxatari/games/mods/frostbite_mods.py @@ -2,7 +2,7 @@ from jaxatari.modification import JaxAtariModController from jaxatari.games.mods.frostbite.frostbite_mod_plugins import ( NoEnemiesMod, LightBlueIceMod, _StaticIceMod, _MisalignedIceMod, _AlignedIceMod, RecoloredObstaclesMod, TigerMod, - WhiteIglooMod, LeftIglooMod, EarlyBearMod, DarkNightMod + WhiteIglooMod, LeftIglooMod, EarlyBearMod, DarkNightMod, BadFishesRewardMod ) # --- The Registry --- @@ -15,6 +15,7 @@ "left_igloo": LeftIglooMod, "early_bear": EarlyBearMod, "dark_night": DarkNightMod, + "bad_fishes_reward": BadFishesRewardMod, "_static_ice": _StaticIceMod, "_misaligned_ice": _MisalignedIceMod, "_aligned_ice": _AlignedIceMod, From 5ca7f96b6e6213d3bf7deb2ccb6b1938e0087a3a Mon Sep 17 00:00:00 2001 From: Quentin Delfosse Date: Wed, 6 May 2026 00:41:48 +0200 Subject: [PATCH 23/28] Various mods implemented 2 --- src/jaxatari/games/jax_beamrider.py | 9 ++++-- src/jaxatari/games/jax_phoenix.py | 28 ++++++++++++++++-- src/jaxatari/games/jax_pong.py | 9 +++++- .../mods/beamrider/beamrider_mod_plugins.py | 29 +++++++++++-------- .../games/mods/phoenix/phoenix_mod_plugins.py | 10 +++++++ src/jaxatari/games/mods/phoenix_mods.py | 2 ++ .../games/mods/pong/pong_mod_plugins.py | 7 ++++- src/jaxatari/games/mods/pong_mods.py | 3 +- 8 files changed, 76 insertions(+), 21 deletions(-) diff --git a/src/jaxatari/games/jax_beamrider.py b/src/jaxatari/games/jax_beamrider.py index 9699bca0f..6ad76e3d7 100644 --- a/src/jaxatari/games/jax_beamrider.py +++ b/src/jaxatari/games/jax_beamrider.py @@ -4623,10 +4623,13 @@ def _render_score(self, raster, state): # Add negative sign if score < 0 is_negative = state.score < 0 digit_mask = score_masks[0] - minus_color_id = jnp.max(digit_mask) - minus_mask = jnp.zeros_like(digit_mask) + transparent_id = self.jr.TRANSPARENT_ID + # Get the actual color ID of the digit + digit_color = jnp.where(digit_mask != transparent_id, digit_mask, jnp.inf).min().astype(digit_mask.dtype) + + minus_mask = jnp.full_like(digit_mask, transparent_id) mid_y = digit_mask.shape[0] // 2 - minus_mask = minus_mask.at[mid_y, 1:-1].set(minus_color_id) + minus_mask = minus_mask.at[mid_y, 1:-1].set(digit_color) raster = jax.lax.cond( is_negative, diff --git a/src/jaxatari/games/jax_phoenix.py b/src/jaxatari/games/jax_phoenix.py index 802ae6de3..1528869f1 100644 --- a/src/jaxatari/games/jax_phoenix.py +++ b/src/jaxatari/games/jax_phoenix.py @@ -168,6 +168,7 @@ class PhoenixConstants(AutoDerivedConstants): PLAYER_PROJECTILE_INITIAL_OFFSET: int = struct.field(pytree_node=False, default=-5) RESET_START_LEVEL: int = struct.field(pytree_node=False, default=1) ENEMY_PROJECTILE_SPEED: int = struct.field(pytree_node=False, default=2) + SHOOTING_REWARD: int = struct.field(pytree_node=False, default=0) # --- Global / shared enemy timing and odds --- ENEMY_DEATH_DURATION: int = struct.field(pytree_node=False, default=30) # ca. 0,25 Sekunden bei 30 FPS @@ -1958,8 +1959,8 @@ def check_collision(entity_pos, projectile_pos): actual_hit_scores = jnp.where(enemy_collisions, enemy_hit_scores, 0) total_hit_score = jnp.sum(actual_hit_scores) - # Update overall score with sub_step (wings) + main kills - score = (state.score + sub_step_score + total_hit_score).astype(jnp.int32) + # Update overall score with sub_step (wings) + main kills + shooting reward + score = (state.score + sub_step_score + total_hit_score + jnp.where(firing, self.consts.SHOOTING_REWARD, 0)).astype(jnp.int32) # Gegner entfernen nach Ablauf der jeweiligen Death-Animation death_done_any = jnp.where(is_bat_level, b_death_done, p_death_done) @@ -2852,7 +2853,11 @@ def _render_ui(self, state, raster): score_y = 10 + score_dy digit_masks = self.get_mask('digits') digit_w = digit_masks[0].shape[1] - score_digits = self.jr.int_to_digits(state.score, max_digits=max_digits) + + # Handle negative score + abs_score = jnp.abs(state.score.astype(jnp.int32)) + score_digits = self.jr.int_to_digits(abs_score, max_digits=max_digits) + has_nonzero = jnp.any(score_digits != 0) first_idx = jnp.where(has_nonzero, jnp.argmax(score_digits != 0), max_digits - 1) num_to_render = jnp.where(has_nonzero, max_digits - first_idx, 1) @@ -2860,6 +2865,23 @@ def _render_ui(self, state, raster): field_total_w = max_digits * spacing base_left = (self.consts.WIDTH - field_total_w) // 2 score_x = base_left + first_idx * spacing + score_dx + + # Render minus sign if negative + is_negative = state.score < 0 + # Use a pixel from digit '0' that is likely to be colored (top-center) + digit_color_id = digit_masks[0][0, digit_w // 2] + raster = jax.lax.cond( + is_negative, + lambda r: self.jr.draw_rects( + r, + positions=jnp.array([[score_x - 7, score_y + 3]]), + sizes=jnp.array([[5, 1]]), + color_id=digit_color_id.astype(r.dtype) + ), + lambda r: r, + raster + ) + raster = self.jr.render_label_selective( raster, score_x, score_y, score_digits, digit_masks, diff --git a/src/jaxatari/games/jax_pong.py b/src/jaxatari/games/jax_pong.py index 87a47ce08..91ad14fac 100644 --- a/src/jaxatari/games/jax_pong.py +++ b/src/jaxatari/games/jax_pong.py @@ -65,6 +65,7 @@ class PongConstants(struct.PyTreeNode): PADDLE_MIN_Y: float = struct.field(pytree_node=False, default=24.0) PADDLE_MAX_Y: float = struct.field(pytree_node=False, default=190.0) PADDLE_DAMPENING_Y: float = struct.field(pytree_node=False, default=170.0) + ACCELERATION_REWARD: float = struct.field(pytree_node=False, default=0.0) class PongState(struct.PyTreeNode): @@ -481,9 +482,15 @@ def _get_info(self, state: PongState, ) -> PongInfo: @partial(jax.jit, static_argnums=(0,)) def _get_reward(self, previous_state: PongState, state: PongState): - return (state.player_score - state.enemy_score) - ( + score_reward = (state.player_score - state.enemy_score) - ( previous_state.player_score - previous_state.enemy_score ) + # Acceleration detection: magnitude increased and sign flipped from positive to negative + accelerated = jnp.logical_and( + jnp.abs(state.ball_vel_x) > jnp.abs(previous_state.ball_vel_x), + jnp.logical_and(previous_state.ball_vel_x > 0, state.ball_vel_x < 0) + ) + return score_reward + jnp.where(accelerated, self.consts.ACCELERATION_REWARD, 0.0) @partial(jax.jit, static_argnums=(0,)) def _get_done(self, state: PongState) -> bool: diff --git a/src/jaxatari/games/mods/beamrider/beamrider_mod_plugins.py b/src/jaxatari/games/mods/beamrider/beamrider_mod_plugins.py index 2026f1e80..a85664344 100644 --- a/src/jaxatari/games/mods/beamrider/beamrider_mod_plugins.py +++ b/src/jaxatari/games/mods/beamrider/beamrider_mod_plugins.py @@ -3,7 +3,7 @@ import jax import jax.numpy as jnp -from jaxatari.modification import JaxAtariInternalModPlugin +from jaxatari.modification import JaxAtariInternalModPlugin, JaxAtariPostStepModPlugin from jaxatari.games.jax_beamrider import ( BLUE_LINE_INIT_TABLE, BeamriderState, @@ -3624,7 +3624,7 @@ def stay_on_top(_): ) -class DontKillMod(JaxAtariInternalModPlugin): +class DontKillMod(JaxAtariInternalModPlugin, JaxAtariPostStepModPlugin): """Internal mod that punishes killing and shooting.""" constants_overrides = { @@ -3639,17 +3639,22 @@ class DontKillMod(JaxAtariInternalModPlugin): } @partial(jax.jit, static_argnums=(0,)) - def _get_reward(self, previous_state: BeamriderState, state: BeamriderState): - # Standard reward (which now includes the negative kill points from overrides) - reward = state.score - previous_state.score - - # Punish shooting + def run(self, prev_state: BeamriderState, new_state: BeamriderState) -> BeamriderState: # A shot is fired when player_shot_frame transitions from -1 to >= 0 shot_fired = jnp.logical_and( - previous_state.level.player_shot_frame == -1, - state.level.player_shot_frame != -1 + prev_state.level.player_shot_frame == -1, + new_state.level.player_shot_frame != -1 ) - # We use a penalty for every shot fired - shooting_penalty = jnp.where(shot_fired, 10.0, 0.0) + shooting_penalty = jnp.where(shot_fired, 10, 0).astype(new_state.score.dtype) - return reward.astype(jnp.float32) - shooting_penalty + # We also need to add survival reward? Not requested. + new_score = new_state.score - shooting_penalty + return new_state.replace(score=new_score) + + @partial(jax.jit, static_argnums=(0,)) + def _get_reward(self, previous_state: BeamriderState, state: BeamriderState): + # We return the difference in score. + # Since we modify the score directly in `run`, `state.score` already includes the penalty. + # So reward is just the difference in score! + reward = state.score - previous_state.score + return reward.astype(jnp.float32) diff --git a/src/jaxatari/games/mods/phoenix/phoenix_mod_plugins.py b/src/jaxatari/games/mods/phoenix/phoenix_mod_plugins.py index 7eaad0bed..a15e9538b 100644 --- a/src/jaxatari/games/mods/phoenix/phoenix_mod_plugins.py +++ b/src/jaxatari/games/mods/phoenix/phoenix_mod_plugins.py @@ -121,6 +121,16 @@ class MatrixMod(JaxAtariInternalModPlugin): 'RGB_BATS_RED': (50, 255, 50), } +class PenalizeShootMod(JaxAtariInternalModPlugin): + """ + Penalize (-50 points) for each time the player shoots. + """ + + constants_overrides = { + "SHOOTING_REWARD": -50, + } + + class BloodMoonMod(JaxAtariInternalModPlugin): """A dark red themed mod.""" name = "blood_moon" diff --git a/src/jaxatari/games/mods/phoenix_mods.py b/src/jaxatari/games/mods/phoenix_mods.py index d986f0777..3acc6461d 100644 --- a/src/jaxatari/games/mods/phoenix_mods.py +++ b/src/jaxatari/games/mods/phoenix_mods.py @@ -6,6 +6,7 @@ InvinciblePlayerMod, FastEnemyBulletsMod, NoAbilityCooldownMod, + PenalizeShootMod, NightMod, GrayscaleMod, InvertedColorsMod, @@ -26,6 +27,7 @@ class PhoenixEnvMod(JaxAtariModController): "invincible_player": InvinciblePlayerMod, "fast_enemy_bullets": FastEnemyBulletsMod, "no_ability_cooldown": NoAbilityCooldownMod, + "penalize_shoot": PenalizeShootMod, "night_mode": NightMod, "grayscale": GrayscaleMod, "inverted_colors": InvertedColorsMod, diff --git a/src/jaxatari/games/mods/pong/pong_mod_plugins.py b/src/jaxatari/games/mods/pong/pong_mod_plugins.py index b6c053141..cd9f3c4ba 100644 --- a/src/jaxatari/games/mods/pong/pong_mod_plugins.py +++ b/src/jaxatari/games/mods/pong/pong_mod_plugins.py @@ -162,4 +162,9 @@ class ChangePlayerColorMod(JaxAtariInternalModPlugin): "type": "single", "data": _recolor_sprite("player.npy", (92, 186, 92), _NEW_PLAYER_COLOR), } - } \ No newline at end of file + } + + +class RewardAccelerationMod(JaxAtariInternalModPlugin): + """Rewards (+2) the player each time the ball is accelerated by the player.""" + constants_overrides = {"ACCELERATION_REWARD": 2.0} \ No newline at end of file diff --git a/src/jaxatari/games/mods/pong_mods.py b/src/jaxatari/games/mods/pong_mods.py index c0ccc2009..f4694ea5e 100644 --- a/src/jaxatari/games/mods/pong_mods.py +++ b/src/jaxatari/games/mods/pong_mods.py @@ -1,6 +1,6 @@ import os from jaxatari.modification import JaxAtariModController -from jaxatari.games.mods.pong.pong_mod_plugins import LazyEnemyMod, RandomEnemyMod, AlwaysZeroScoreMod, LinearMovementMod, ShiftPlayerMod, ShiftEnemyMod, NoFireMod, ChangeBackgroundColorMod, ChangePlayerColorMod +from jaxatari.games.mods.pong.pong_mod_plugins import LazyEnemyMod, RandomEnemyMod, AlwaysZeroScoreMod, LinearMovementMod, ShiftPlayerMod, ShiftEnemyMod, NoFireMod, ChangeBackgroundColorMod, ChangePlayerColorMod, RewardAccelerationMod class PongEnvMod(JaxAtariModController): """ @@ -18,6 +18,7 @@ class PongEnvMod(JaxAtariModController): "no_fire": NoFireMod, "change_background_color": ChangeBackgroundColorMod, "change_player_color": ChangePlayerColorMod, + "reward_acceleration": RewardAccelerationMod, } _mod_sprite_dir = os.path.join(os.path.dirname(__file__), "pong", "sprites") From 8f7e46adf2601daf29d5772a3fe0940e73d5521d Mon Sep 17 00:00:00 2001 From: Quentin Delfosse Date: Wed, 6 May 2026 00:50:25 +0200 Subject: [PATCH 24/28] Qbert reward mods --- src/jaxatari/games/jax_qbert.py | 22 ++++++++++++++- .../games/mods/qbert/qbert_mod_plugins.py | 28 ++++++++++++++++++- src/jaxatari/games/mods/qbert_mods.py | 4 ++- 3 files changed, 51 insertions(+), 3 deletions(-) diff --git a/src/jaxatari/games/jax_qbert.py b/src/jaxatari/games/jax_qbert.py index 8da4ce45b..0d09ed62c 100644 --- a/src/jaxatari/games/jax_qbert.py +++ b/src/jaxatari/games/jax_qbert.py @@ -1430,11 +1430,31 @@ def get_mask(key): raster = self._draw_colors(raster, state, pyra) # Vectorized Score - player_score_digits = jr.int_to_digits(state.player_score, max_digits=5) + abs_score = jnp.abs(state.player_score) + player_score_digits = jr.int_to_digits(abs_score, max_digits=5) digit_idx = self.SCORE_MAP score_pixel_id = M['score_digits'][player_score_digits[digit_idx], self.SCORE_LOCAL_Y, self.SCORE_LOCAL_X] raster = jnp.where((digit_idx != -1) & (state.player_position[1] >= 3), score_pixel_id, raster) + # Draw minus sign if score is negative + def draw_minus(r): + # Use a simple horizontal line as a minus sign + # Digit color can be extracted from digit 0 mask + transparent_id = self.jr.TRANSPARENT_ID + mask_0 = M['score_digits'][0] + digit_color = jnp.where(mask_0 != transparent_id, mask_0, jnp.inf).min().astype(mask_0.dtype) + # Create a 5x7 mask for the minus sign + minus_mask = jnp.full((7, 5), transparent_id, dtype=mask_0.dtype) + minus_mask = minus_mask.at[3, 0:4].set(digit_color) + return self.jr.render_at(r, int(round(26 * self.config.width_scaling)), 6, minus_mask) + + raster = jax.lax.cond( + state.player_score < 0, + draw_minus, + lambda r: r, + raster + ) + # Vectorized Lives live_idx = self.LIVES_MAP live_pixel_id = M['qbert_live'][self.LIVES_LOCAL_Y, self.LIVES_LOCAL_X] diff --git a/src/jaxatari/games/mods/qbert/qbert_mod_plugins.py b/src/jaxatari/games/mods/qbert/qbert_mod_plugins.py index b97fe6e03..e2eba8616 100644 --- a/src/jaxatari/games/mods/qbert/qbert_mod_plugins.py +++ b/src/jaxatari/games/mods/qbert/qbert_mod_plugins.py @@ -436,9 +436,25 @@ def render(self, state: QbertState) -> jnp.ndarray: init_val=raster, ) - player_score_digits = jr.int_to_digits(state.player_score, max_digits=5) + player_score_digits = jr.int_to_digits(jnp.abs(state.player_score), max_digits=5) raster = jnp.where(state.player_position[1] >= 3, jr.render_label_selective(raster, 34, 6, player_score_digits, M['score_digits'], 0, 5, spacing=8, max_digits_to_render=5), raster) + # Draw minus sign if score is negative + def draw_minus(r): + transparent_id = self._env.renderer.jr.TRANSPARENT_ID + mask_0 = M['score_digits'][0] + digit_color = jnp.where(mask_0 != transparent_id, mask_0, jnp.inf).min().astype(mask_0.dtype) + minus_mask = jnp.full((7, 5), transparent_id, dtype=mask_0.dtype) + minus_mask = minus_mask.at[3, 0:4].set(digit_color) + return self._env.renderer.jr.render_at(r, int(round(26 * self._env.renderer.config.width_scaling)), 6, minus_mask) + + raster = jax.lax.cond( + state.player_score < 0, + draw_minus, + lambda r: r, + raster + ) + raster = jax.lax.fori_loop( lower=0, upper=state.lives, @@ -671,3 +687,13 @@ class RedCoilyMod(JaxAtariInternalModPlugin): "RGB_COILY": (173, 5, 64), } +class PenalizeAllCollectablesMod(JaxAtariInternalModPlugin): + """ + Penalizes collecting items that normally give positive rewards. + """ + constants_overrides = { + "GREEN_BALL_REWARD": -100, + "SAM_REWARD": -300, + "COILY_REWARD": -500, + } + diff --git a/src/jaxatari/games/mods/qbert_mods.py b/src/jaxatari/games/mods/qbert_mods.py index f5e4d2dfd..a03ef4f01 100644 --- a/src/jaxatari/games/mods/qbert_mods.py +++ b/src/jaxatari/games/mods/qbert_mods.py @@ -15,7 +15,8 @@ GrayscaleMod, InvertedColorsMod, SwapCollectiblesEnemiesMod, - RedCoilyMod + RedCoilyMod, + PenalizeAllCollectablesMod ) class QbertEnvMod(JaxAtariModController): @@ -39,6 +40,7 @@ class QbertEnvMod(JaxAtariModController): "inverted_colors": InvertedColorsMod, "swap_collectibles_enemies": SwapCollectiblesEnemiesMod, "red_coily": RedCoilyMod, + "penalize_all_collectables": PenalizeAllCollectablesMod, } def __init__(self, From 2a5cfa6cdb38992eddad03620db19f2a27da9747 Mon Sep 17 00:00:00 2001 From: Quentin Delfosse Date: Wed, 6 May 2026 01:10:37 +0200 Subject: [PATCH 25/28] Various mods implemented 3 --- src/jaxatari/games/jax_enduro.py | 13 ++++++- src/jaxatari/games/jax_gravitar.py | 9 +++-- src/jaxatari/games/jax_spaceinvaders.py | 36 +++++++++++++++++-- .../games/mods/enduro/enduro_mod_plugins.py | 16 +++++++++ src/jaxatari/games/mods/enduro_mods.py | 4 ++- .../mods/gravitar/gravitar_mod_plugins.py | 9 +++++ src/jaxatari/games/mods/gravitar_mods.py | 2 ++ .../spaceinvaders_mod_plugins.py | 10 ++++++ src/jaxatari/games/mods/spaceinvaders_mods.py | 4 ++- 9 files changed, 96 insertions(+), 7 deletions(-) diff --git a/src/jaxatari/games/jax_enduro.py b/src/jaxatari/games/jax_enduro.py index a77047ac1..404fe123e 100644 --- a/src/jaxatari/games/jax_enduro.py +++ b/src/jaxatari/games/jax_enduro.py @@ -290,6 +290,8 @@ class EnduroConstants(AutoDerivedConstants): secondary_collision_overlap_exception_max: int = struct.field(pytree_node=False, default=2) initial_position: int = struct.field(pytree_node=False, default=200) next_day_car_position: int = struct.field(pytree_node=False, default=300) + max_speed_reward: float = struct.field(pytree_node=False, default=0.0) + disable_reward_at_max_speed: bool = struct.field(pytree_node=False, default=False) # Difficulty Scaling (opponent_speed scaled from 24→2, so increments scale by 2/24) start_level: int = struct.field(pytree_node=False, default=1) @@ -1629,7 +1631,16 @@ def _get_reward(self, state: EnduroGameState, new_state: EnduroGameState) -> flo day_rolled_over = new_state.day_count > state.day_count cars_overtaken = (state.cars_to_pass - new_state.cars_to_pass).astype(jnp.float32) cars_overtaken = jnp.where(day_rolled_over, 0.0, cars_overtaken) - return cars_overtaken + + # Max speed reward + is_max_speed = new_state.player_speed >= self.consts.max_speed + is_20th_frame = (new_state.step_count % 20) == 0 + max_speed_reward = jnp.where(is_max_speed & is_20th_frame, self.consts.max_speed_reward, 0.0) + + total_reward = cars_overtaken + max_speed_reward + total_reward = jnp.where(self.consts.disable_reward_at_max_speed & is_max_speed, 0.0, total_reward) + + return total_reward @partial(jax.jit, static_argnums=(0,)) def _get_done(self, state: EnduroGameState) -> bool: diff --git a/src/jaxatari/games/jax_gravitar.py b/src/jaxatari/games/jax_gravitar.py index 31f7ffda3..a2e50834e 100644 --- a/src/jaxatari/games/jax_gravitar.py +++ b/src/jaxatari/games/jax_gravitar.py @@ -195,6 +195,7 @@ class GravitarConstants(struct.PyTreeNode): LEVEL_CLEAR_SCORE: float = struct.field(pytree_node=False, default=1000.0) UFO_KILL_SCORE: float = struct.field(pytree_node=False, default=100.0) SAUCER_KILL_SCORE: float = struct.field(pytree_node=False, default=100.0) + FIRE_REWARD: float = struct.field(pytree_node=False, default=0.0) # Bonuses SOLAR_SYSTEM_BONUS_FUEL: float = struct.field(pytree_node=False, default=7000.0) @@ -705,6 +706,7 @@ class EnvState: level_clear_score: jnp.ndarray # float32 ufo_kill_score: jnp.ndarray # float32 saucer_kill_score: jnp.ndarray # float32 + fire_reward: jnp.ndarray # float32 thrust_power: jnp.ndarray # float32 (unscaled; divided by WORLD_SCALE in physics) max_speed: jnp.ndarray # float32 (unscaled; divided by WORLD_SCALE in physics) prev_action: jnp.ndarray # int32, previous action taken @@ -1573,6 +1575,7 @@ def create_env_state(rng: jnp.ndarray) -> EnvState: level_clear_score=jnp.float32(LEVEL_CLEAR_SCORE), ufo_kill_score=jnp.float32(UFO_KILL_SCORE), saucer_kill_score=jnp.float32(SAUCER_KILL_SCORE), + fire_reward=jnp.float32(_DEFAULT_CONSTS.FIRE_REWARD), prev_action=jnp.int32(0), ) @@ -2759,7 +2762,8 @@ def _unified_game_loop(state, act): # Level: lives decrease on death_event; map/arena: handled externally lives_after_death = state_after_spawn.lives - jnp.where(death_event_level, 1, 0) score_before = state_after_spawn.score - level_score_after = score_before + level_score_delta + score_from_fire = jnp.where(can_fire_player, state_after_spawn.fire_reward, 0.0) + level_score_after = score_before + level_score_delta + score_from_fire bonus_life_crossed = (level_score_after // 10000) > (score_before // 10000) lives_gained_from_score = jnp.where(bonus_life_crossed & is_level, 1, 0) final_lives = lives_after_death + lives_gained_from_score @@ -2767,7 +2771,7 @@ def _unified_game_loop(state, act): # Map/arena: saucer reward; level: turret/clear/UFO rewards map_score_delta = reward_saucer - score_delta = jnp.where(is_level, level_score_delta, map_score_delta) + score_delta = jnp.where(is_level, level_score_delta, map_score_delta) + score_from_fire fuel_next = jnp.maximum(0.0, state_after_spawn.fuel - fuel_consumed + fuel_gained) # mode_timer: map increments from map_mode_timer computed above; arena also increments; level increments @@ -3452,6 +3456,7 @@ def reset_map(self, key: jnp.ndarray, level_clear_score=jnp.float32(self.consts.LEVEL_CLEAR_SCORE), ufo_kill_score=jnp.float32(self.consts.UFO_KILL_SCORE), saucer_kill_score=jnp.float32(self.consts.SAUCER_KILL_SCORE), + fire_reward=jnp.float32(self.consts.FIRE_REWARD), prev_action=jnp.int32(0), ) diff --git a/src/jaxatari/games/jax_spaceinvaders.py b/src/jaxatari/games/jax_spaceinvaders.py index 68a591247..cf9125642 100644 --- a/src/jaxatari/games/jax_spaceinvaders.py +++ b/src/jaxatari/games/jax_spaceinvaders.py @@ -96,6 +96,8 @@ class SpaceInvadersConstants(AutoDerivedConstants): POSITION_LIFE_X: int = struct.field(pytree_node=False, default=83) + SHOOTING_REWARD: int = struct.field(pytree_node=False, default=0) + # Thresholds: [15, 29, 33, 34, 35] (Total destroyed count) MOVEMENT_THRESHOLDS: jnp.ndarray = struct.field(pytree_node=False, default_factory=lambda: jnp.array([15, 29, 33, 34, 35], dtype=jnp.int32)) # Rates: Delays in frames. @@ -642,6 +644,8 @@ def step_running(self, state: SpaceInvadersState, action: chex.Array, key) -> Sp new_bullet_active, new_bullet_x, new_bullet_y = self._player_bullet_step(state, action) + actually_fired = new_bullet_active & ~state.bullet_active + new_bullet_state = state.replace( bullet_active=new_bullet_active, bullet_x=new_bullet_x, @@ -650,6 +654,8 @@ def step_running(self, state: SpaceInvadersState, action: chex.Array, key) -> Sp new_destroyed, new_score, final_bullet_active, new_ufo_state = self._check_bullet_enemy_collisions(new_bullet_state) + new_score = new_score + jnp.where(actually_fired, self.consts.SHOOTING_REWARD, 0) + def find_bounds(has_occupied, limit): # Identify first and last indices first = jnp.argmax(has_occupied) @@ -1217,7 +1223,8 @@ def render_life(self, state: SpaceInvadersState, raster): @partial(jax.jit, static_argnums=(0,)) def render_score(self, state: SpaceInvadersState, raster): - digits = self.jr.int_to_digits(state.player_score, max_digits=4) + abs_score = jnp.abs(state.player_score.astype(jnp.int32)) + digits = self.jr.int_to_digits(abs_score, max_digits=4) score_pos_x = jnp.array([3, 6, 9, 12], dtype=jnp.int32) score_pos_y = jnp.array([9, 10, 10, 9], dtype=jnp.int32) yellow_zero_mask = self.SHAPE_MASKS['zero_yellow'] @@ -1231,7 +1238,32 @@ def render_digit_and_zero(i, r): yy = score_pos_y[i] return self.jr.render_at(r, xy, yy, yellow_zero_mask) - return jax.lax.fori_loop(0, 4, render_digit_and_zero, raster) + raster = jax.lax.fori_loop(0, 4, render_digit_and_zero, raster) + + # Handle minus sign + has_nonzero = jnp.any(digits != 0) + first_idx = jnp.where(has_nonzero, jnp.argmax(digits != 0), 3) + is_negative = state.player_score < 0 + + # Use a pixel from digit '0' (green) for the minus sign color + digit_color_id = self.SHAPE_MASKS['digits_green'][0][0, self.consts.NUMBER_SIZE[0] // 2] + + minus_x = score_pos_x[first_idx] + first_idx * self.consts.NUMBER_SIZE[0] - 6 + minus_y = score_pos_y[first_idx] + 4 + + raster = jax.lax.cond( + is_negative, + lambda r: self.jr.draw_rects( + r, + positions=jnp.array([[minus_x, minus_y]]), + sizes=jnp.array([[4, 1]]), + color_id=digit_color_id.astype(r.dtype) + ), + lambda r: r, + raster + ) + + return raster @partial(jax.jit, static_argnums=(0,)) def render_defense(self, state: SpaceInvadersState, raster): diff --git a/src/jaxatari/games/mods/enduro/enduro_mod_plugins.py b/src/jaxatari/games/mods/enduro/enduro_mod_plugins.py index 986c30816..a45fbcb26 100644 --- a/src/jaxatari/games/mods/enduro/enduro_mod_plugins.py +++ b/src/jaxatari/games/mods/enduro/enduro_mod_plugins.py @@ -214,3 +214,19 @@ class NoOpponentsMod(JaxAtariInternalModPlugin): "opponent_density": 0.0, "opponent_density_increment": 0.0 } + +class MaxSpeedRewardMod(JaxAtariInternalModPlugin): + """ + Provides +2 every 20th frame when at maximum speed. + """ + constants_overrides = { + "max_speed_reward": 2.0 + } + +class DontMaxSpeedMod(JaxAtariInternalModPlugin): + """ + Does not reward the player if the player is at maximum speed. + """ + constants_overrides = { + "disable_reward_at_max_speed": True + } diff --git a/src/jaxatari/games/mods/enduro_mods.py b/src/jaxatari/games/mods/enduro_mods.py index d2fb728d5..ea707c0c5 100644 --- a/src/jaxatari/games/mods/enduro_mods.py +++ b/src/jaxatari/games/mods/enduro_mods.py @@ -1,7 +1,7 @@ from jaxatari.modification import JaxAtariModController from jaxatari.games.mods.enduro.enduro_mod_plugins import SpeedAndXPosHudMod, StartInCurveMod, \ StartInMaxCurveMod, FilledRoadMod, SnowWeatherMod, NightWeatherMod, FogWeatherMod, DayWeatherMod, \ - SunsetWeatherMod, DawnWeatherMod, ShortDaysMod, NoOpponentsMod + SunsetWeatherMod, DawnWeatherMod, ShortDaysMod, NoOpponentsMod, MaxSpeedRewardMod, DontMaxSpeedMod class EnduroEnvMod(JaxAtariModController): """ @@ -20,6 +20,8 @@ class EnduroEnvMod(JaxAtariModController): "dawn": DawnWeatherMod, "short_days": ShortDaysMod, "no_opponents": NoOpponentsMod, + "max_speed_reward": MaxSpeedRewardMod, + "dont_max_speed": DontMaxSpeedMod, } def __init__(self, diff --git a/src/jaxatari/games/mods/gravitar/gravitar_mod_plugins.py b/src/jaxatari/games/mods/gravitar/gravitar_mod_plugins.py index 7243f639e..348556dd7 100644 --- a/src/jaxatari/games/mods/gravitar/gravitar_mod_plugins.py +++ b/src/jaxatari/games/mods/gravitar/gravitar_mod_plugins.py @@ -164,3 +164,12 @@ class InvertedColorsMod(JaxAtariInternalModPlugin): {"source": (84, 160, 197), "target": (171, 95, 58)}, ) } + + +class SharpshooterMod(JaxAtariInternalModPlugin): + """Punishes shooting (-100) and rewards killing enemies (+500).""" + + constants_overrides = { + "FIRE_REWARD": -100.0, + "ENEMY_KILL_SCORE": 500.0, + } diff --git a/src/jaxatari/games/mods/gravitar_mods.py b/src/jaxatari/games/mods/gravitar_mods.py index 6f109d0ef..9f959ddf5 100644 --- a/src/jaxatari/games/mods/gravitar_mods.py +++ b/src/jaxatari/games/mods/gravitar_mods.py @@ -15,6 +15,7 @@ RedAlertMod, GrayscaleMod, InvertedColorsMod, + SharpshooterMod, ) @@ -37,6 +38,7 @@ class GravitarEnvMod(JaxAtariModController): "red_alert": RedAlertMod, "grayscale": GrayscaleMod, "inverted_colors": InvertedColorsMod, + "sharpshooter": SharpshooterMod, } def __init__( diff --git a/src/jaxatari/games/mods/spaceinvaders/spaceinvaders_mod_plugins.py b/src/jaxatari/games/mods/spaceinvaders/spaceinvaders_mod_plugins.py index 45f2c6f99..e5ff7453b 100644 --- a/src/jaxatari/games/mods/spaceinvaders/spaceinvaders_mod_plugins.py +++ b/src/jaxatari/games/mods/spaceinvaders/spaceinvaders_mod_plugins.py @@ -106,3 +106,13 @@ def after_reset(self, obs, state): enemy_bullets_active=jnp.zeros_like(state.enemy_bullets_active, dtype=jnp.bool_) ) return self._env._get_observation(state), state + + +class PenalizeShootMod(JaxAtariInternalModPlugin): + """ + Penalize (-5 points) for each time the player shoots. + """ + + constants_overrides = { + "SHOOTING_REWARD": -5, + } diff --git a/src/jaxatari/games/mods/spaceinvaders_mods.py b/src/jaxatari/games/mods/spaceinvaders_mods.py index 50715db01..bb99b96eb 100644 --- a/src/jaxatari/games/mods/spaceinvaders_mods.py +++ b/src/jaxatari/games/mods/spaceinvaders_mods.py @@ -6,7 +6,8 @@ DisableShieldRightMod, ShiftShieldsMod, ControllableMissileMod, - NoDangerMod + NoDangerMod, + PenalizeShootMod ) class SpaceInvadersEnvMod(JaxAtariModController): @@ -21,6 +22,7 @@ class SpaceInvadersEnvMod(JaxAtariModController): "shift_shields": ShiftShieldsMod, "controllable_missile": ControllableMissileMod, "no_danger": NoDangerMod, + "penalize_shoot": PenalizeShootMod, } def __init__(self, From 4cb29d4ee87ed40590fb5950a2d00cdda395fdb9 Mon Sep 17 00:00:00 2001 From: Quentin Delfosse Date: Wed, 6 May 2026 18:08:38 +0200 Subject: [PATCH 26/28] Asteroids mod: dont_shoot --- src/jaxatari/core.py | 3 ++- src/jaxatari/games/jax_asteroids.py | 26 +++++++++++++++---- src/jaxatari/rendering/jax_rendering_utils.py | 12 ++++----- 3 files changed, 29 insertions(+), 12 deletions(-) diff --git a/src/jaxatari/core.py b/src/jaxatari/core.py index 2f9459393..14d83c7f1 100644 --- a/src/jaxatari/core.py +++ b/src/jaxatari/core.py @@ -105,7 +105,8 @@ def _warn_deprecated_obs_to_flat_array(env: JaxEnvironment) -> None: "spaceinvaders": "jaxatari.games.mods.spaceinvaders_mods.SpaceInvadersEnvMod", "skiing": "jaxatari.games.mods.skiing_mods.SkiingEnvMod", "alien": "jaxatari.games.mods.alien_mods.AlienEnvMod", - "timepilot": "jaxatari.games.mods.timepilot_mods.TimePilotEnvMod" + "timepilot": "jaxatari.games.mods.timepilot_mods.TimePilotEnvMod", + "asteroids": "jaxatari.games.mods.asteroids_mods.AsteroidsEnvMod" } diff --git a/src/jaxatari/games/jax_asteroids.py b/src/jaxatari/games/jax_asteroids.py index 5622ab159..2bbbc2333 100644 --- a/src/jaxatari/games/jax_asteroids.py +++ b/src/jaxatari/games/jax_asteroids.py @@ -16,8 +16,12 @@ def _create_static_procedural_sprites() -> dict: """Creates procedural sprites that don't depend on dynamic values.""" # Create a procedural sprite for the wall color to ensure it's in the palette wall_color_rgba = jnp.array([0, 0, 0, 255], dtype=jnp.uint8).reshape(1, 1, 4) + # Create a minus sign for negative scores + minus_sign = jnp.zeros((10, 12, 4), dtype=jnp.uint8) + minus_sign = minus_sign.at[4:6, 2:10, :].set(255) return { 'wall_color': wall_color_rgba, + 'minus_sign': minus_sign, } def _get_default_asset_config() -> tuple: @@ -1347,13 +1351,25 @@ def _draw(ras): wall_sizes = jnp.array([[self.consts.WIDTH, self.consts.WALL_TOP_HEIGHT], [self.consts.WIDTH, self.consts.WALL_BOTTOM_HEIGHT]]) raster = self.jr.draw_rects(raster, wall_positions, wall_sizes, wall_color_id) def _get_number_of_digits(val): - return jax.lax.cond(val < 10, lambda: 1, lambda: - jax.lax.cond(val < 100, lambda: 2, lambda: - jax.lax.cond(val < 1000, lambda: 3, lambda: - jax.lax.cond(val < 10000, lambda: 4, lambda: 5)))) + val_abs = jnp.abs(val) + return jax.lax.cond(val_abs < 10, lambda: 1, lambda: + jax.lax.cond(val_abs < 100, lambda: 2, lambda: + jax.lax.cond(val_abs < 1000, lambda: 3, lambda: + jax.lax.cond(val_abs < 10000, lambda: 4, lambda: 5)))) score_digits_arr = self.jr.int_to_digits(state.score, max_digits=5) num_score_digits = _get_number_of_digits(state.score) - raster = self.jr.render_label_selective(raster, 68 - 16 * (num_score_digits - 1), 5, + + # Render minus sign if negative + is_negative = state.score < 0 + score_x = 68 - 16 * (num_score_digits - 1) + raster = jax.lax.cond( + is_negative, + lambda r: self.jr.render_at(r, score_x - 16, 5, self.SHAPE_MASKS['minus_sign']), + lambda r: r, + raster + ) + + raster = self.jr.render_label_selective(raster, score_x, 5, score_digits_arr, self.SHAPE_MASKS['digits'], 5 - num_score_digits, num_score_digits, spacing=16, max_digits_to_render=5) lives_digits_arr = self.jr.int_to_digits(state.lives, max_digits=1) diff --git a/src/jaxatari/rendering/jax_rendering_utils.py b/src/jaxatari/rendering/jax_rendering_utils.py index 01211fea2..cd96f5bbe 100644 --- a/src/jaxatari/rendering/jax_rendering_utils.py +++ b/src/jaxatari/rendering/jax_rendering_utils.py @@ -1329,13 +1329,13 @@ def apply_updates(p): @partial(jax.jit, static_argnames=["max_digits", "self"]) def int_to_digits(self, n, max_digits=8): """ - Convert a non-negative integer or a batch of integers to a fixed-length - JAX array of digits. Handles both scalar and batched inputs. + Convert an integer or a batch of integers to a fixed-length + JAX array of digits (using absolute value). """ - # This logic works whether 'n' is a scalar or a batched array. - n = jnp.maximum(n, 0) + # Use absolute value to support negative numbers + n_abs = jnp.abs(n) max_val = 10**max_digits - 1 - n = jnp.minimum(n, max_val) + n_abs = jnp.minimum(n_abs, max_val) def scan_body(carry, _): digit = carry % 10 @@ -1344,7 +1344,7 @@ def scan_body(carry, _): # lax.scan on a batched `n` produces a shape of (length, batch_size). # On a scalar `n`, it produces a shape of (length,). - _, digits_reversed = jax.lax.scan(scan_body, n, None, length=max_digits) + _, digits_reversed = jax.lax.scan(scan_body, n_abs, None, length=max_digits) # Flip to get digits in the correct order (most significant first). digits = jnp.flip(digits_reversed, axis=0) From f6f943b3140d77d3534108b06db9d732be9dbb2e Mon Sep 17 00:00:00 2001 From: Quentin Delfosse Date: Wed, 6 May 2026 18:26:40 +0200 Subject: [PATCH 27/28] Asteroids mod: all --- src/jaxatari/games/mods/asteroids/__init__.py | 0 .../mods/asteroids/asteroids_mod_plugins.py | 345 ++++++++++++++++++ src/jaxatari/games/mods/asteroids_mods.py | 30 ++ 3 files changed, 375 insertions(+) create mode 100644 src/jaxatari/games/mods/asteroids/__init__.py create mode 100644 src/jaxatari/games/mods/asteroids/asteroids_mod_plugins.py create mode 100644 src/jaxatari/games/mods/asteroids_mods.py diff --git a/src/jaxatari/games/mods/asteroids/__init__.py b/src/jaxatari/games/mods/asteroids/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/jaxatari/games/mods/asteroids/asteroids_mod_plugins.py b/src/jaxatari/games/mods/asteroids/asteroids_mod_plugins.py new file mode 100644 index 000000000..e1eb98019 --- /dev/null +++ b/src/jaxatari/games/mods/asteroids/asteroids_mod_plugins.py @@ -0,0 +1,345 @@ +import jax.numpy as jnp +import numpy as np +import os +from jaxatari.modification import JaxAtariPostStepModPlugin, JaxAtariInternalModPlugin +from jaxatari.games.jax_asteroids import AsteroidsState +from jaxatari.environment import JAXAtariAction as Action +import jax.lax +from functools import partial +import jax +from jaxatari.rendering.jax_rendering_utils import get_base_sprite_dir + +class DontShootMod(JaxAtariPostStepModPlugin): + """ + Mod that provides a reward of 20 every 300 frames but penalizes shooting by -5. + This mod updates the state's score to reflect these changes. + """ + @partial(jax.jit, static_argnums=(0,)) + def run(self, prev_state: AsteroidsState, new_state: AsteroidsState): + periodic_reward = jnp.where( + (new_state.step_counter % 300 == 0) & (new_state.step_counter > 0), + 20, + 0 + ) + + missile_lifespan = self._env.consts.MISSILE_LIFESPAN + shot_fired = jnp.any( + (new_state.missile_states[:, 5] == missile_lifespan) & + (prev_state.missile_states[:, 5] != missile_lifespan) + ) + penalty = jnp.where(shot_fired, -5, 0) + + return new_state.replace( + score=new_state.score + periodic_reward + penalty + ) + +def _recolor_all(sprite: np.ndarray, new_rgb: tuple) -> np.ndarray: + """Replaces all non-transparent pixels with new_rgb.""" + sprite = sprite.copy() + if sprite.shape[-1] == 3: + is_transparent = (sprite[:, :, 0] == 0) & (sprite[:, :, 1] == 0) & (sprite[:, :, 2] == 0) + alpha = np.where(is_transparent, 0, 255).astype(np.uint8) + sprite = np.concatenate([sprite, alpha[..., None]], axis=-1) + mask = sprite[..., 3] > 128 + sprite[mask, :3] = new_rgb + return sprite + +def _load_and_recolor_group(filenames, new_rgb, transpose=False) -> list: + base_dir = os.path.join(get_base_sprite_dir(), "asteroids") + sprites = [] + for f in filenames: + sprite = np.load(os.path.join(base_dir, f)) + if transpose: + sprite = np.transpose(sprite, (1, 0, 2)) + sprites.append(jnp.array(_recolor_all(sprite, new_rgb))) + return sprites + +def _load_and_recolor_single(filename, new_rgb, transpose=False) -> jnp.ndarray: + base_dir = os.path.join(get_base_sprite_dir(), "asteroids") + sprite = np.load(os.path.join(base_dir, filename)) + if transpose: + sprite = np.transpose(sprite, (1, 0, 2)) + return jnp.array(_recolor_all(sprite, new_rgb)) + +def _get_player_group_recolored(): + player_files = [f'player_pos{i}.npy' for i in range(16)] + [f'death_player{i}.npy' for i in range(3)] + return _load_and_recolor_group(player_files, (150, 255, 150)) + +def _get_asteroid_group_recolored(): + asteroid_files = [] + for size in ['big1', 'big2', 'medium', 'small']: + for color in ['brown', 'grey', 'lightblue', 'lightyellow', 'pink', 'purple', 'red', 'yellow']: + asteroid_files.append(f'asteroid_{size}_{color}.npy') + for size in ['big', 'medium', 'small']: + for color in ['pink', 'yellow']: + asteroid_files.append(f'death_{size}_{color}.npy') + return _load_and_recolor_group(asteroid_files, (0, 200, 0)) + +def _get_digits_recolored() -> jnp.ndarray: + sprites = _load_and_recolor_group([f'{i}.npy' for i in range(10)], (0, 255, 0)) + max_height = max(s.shape[0] for s in sprites) + max_width = max(s.shape[1] for s in sprites) + padded_digits = [] + for digit in sprites: + digit = np.array(digit) + pad_h = max_height - digit.shape[0] + pad_w = max_width - digit.shape[1] + pad_top = pad_h // 2 + pad_bottom = pad_h - pad_top + pad_left = pad_w // 2 + pad_right = pad_w - pad_left + padded_digit = np.pad( + digit, + ((pad_top, pad_bottom), (pad_left, pad_right), (0, 0)), + mode="constant", + constant_values=0, + ) + padded_digits.append(padded_digit) + return jnp.stack([jnp.array(p) for p in padded_digits]) + +class MatrixMod(JaxAtariInternalModPlugin): + """A Matrix-themed mod for Asteroids: black background, green elements.""" + name = "matrix_theme" + + constants_overrides = { + 'WALL_COLOR': (0, 100, 0), + } + + asset_overrides = { + 'player_group': { + 'name': 'player_group', + 'type': 'group', + 'data': _get_player_group_recolored() + }, + 'asteroid_group': { + 'name': 'asteroid_group', + 'type': 'group', + 'data': _get_asteroid_group_recolored() + }, + 'missile1': { + 'name': 'missile1', + 'type': 'single', + 'data': _load_and_recolor_single('missile1.npy', (0, 255, 0)) + }, + 'missile2': { + 'name': 'missile2', + 'type': 'single', + 'data': _load_and_recolor_single('missile2.npy', (0, 255, 0)) + }, + 'digits': { + 'name': 'digits', + 'type': 'digits', + 'data': _get_digits_recolored() + }, + 'minus_sign': { + 'name': 'minus_sign', + 'type': 'procedural', + 'data': jnp.zeros((10, 12, 4), dtype=jnp.uint8).at[4:6, 2:10, :].set(jnp.array([0, 255, 0, 255], dtype=jnp.uint8)) + }, + 'wall_color': { + 'name': 'wall_color', + 'type': 'procedural', + 'data': jnp.array([0, 100, 0, 255], dtype=jnp.uint8).reshape(1, 1, 4) + } + } + +class SlowAsteroidsMod(JaxAtariInternalModPlugin): + """ + Mod that slows down asteroids by only updating their positions every 2nd frame. + """ + @partial(jax.jit, static_argnums=(0,)) + def asteroids_step(self, asteroids_state: AsteroidsState): + should_move = asteroids_state.step_counter % 2 == 0 + def _move_logic(_): + asteroid_states = asteroids_state.asteroid_states + side_step_counter = asteroids_state.side_step_counter + rng_key, subkey = jax.random.split(asteroids_state.rng_key) + counter_step = jax.random.randint(subkey, [], 7, 10) + side_step = jnp.logical_and(side_step_counter <= counter_step, side_step_counter != 0) + new_side_step_counter = jax.lax.cond( + side_step_counter < counter_step, + lambda: 115 + side_step_counter - counter_step, + lambda: side_step_counter - counter_step + ) + @jax.jit + def update_asteroid(i, asteroid_states): + ret = jnp.copy(asteroid_states) + axis_directions = jax.lax.switch( + ret[i][2], + [ + lambda: (self._env.consts.ASTEROID_SPEED[0], self._env.consts.ASTEROID_SPEED[1]), + lambda: (-self._env.consts.ASTEROID_SPEED[0], self._env.consts.ASTEROID_SPEED[1]), + lambda: (-self._env.consts.ASTEROID_SPEED[0], -self._env.consts.ASTEROID_SPEED[1]), + lambda: (self._env.consts.ASTEROID_SPEED[0], -self._env.consts.ASTEROID_SPEED[1]) + ] + ) + return ret.at[i].set(jax.lax.cond( + ret[i][3] != self._env.consts.INACTIVE, + lambda: jnp.array([self._env.final_pos(self._env.consts.MIN_ENTITY_X, + self._env.consts.MAX_ENTITY_X, + jax.lax.cond( + side_step, + lambda: ret[i][0] + axis_directions[0], + lambda: ret[i][0])), + self._env.final_pos(self._env.consts.MIN_ENTITY_Y, + self._env.consts.MAX_ENTITY_Y, + ret[i][1] + axis_directions[1]), + ret[i][2], ret[i][3], ret[i][4]]), + lambda: ret[i] + )) + new_asteroid_states = jax.lax.fori_loop(0, self._env.consts.MAX_NUMBER_OF_ASTEROIDS, update_asteroid, asteroid_states) + return new_asteroid_states, new_side_step_counter, rng_key + + def _no_move_logic(_): + return asteroids_state.asteroid_states, asteroids_state.side_step_counter, asteroids_state.rng_key + + return jax.lax.cond(should_move, _move_logic, _no_move_logic, operand=None) + +class InstantTurnMod(JaxAtariInternalModPlugin): + """Directly places the ship in the direction given by the action and applies thrust.""" + name = "instant_turn" + + attribute_overrides = { + "ACTION_SET": jnp.array( + [ + Action.NOOP, + Action.FIRE, + Action.UP, + Action.RIGHT, + Action.LEFT, + Action.DOWN, + Action.UPRIGHT, + Action.UPLEFT, + Action.DOWNRIGHT, + Action.DOWNLEFT, + Action.UPFIRE, + Action.RIGHTFIRE, + Action.LEFTFIRE, + Action.DOWNFIRE, + Action.UPRIGHTFIRE, + Action.UPLEFTFIRE, + Action.DOWNRIGHTFIRE, + Action.DOWNLEFTFIRE, + ], + dtype=jnp.int32, + ) + } + + @partial(jax.jit, static_argnums=(0,)) + def player_step( + self, + state_player_x, + state_player_y, + state_player_speed_x, + state_player_speed_y, + state_player_rotation, + action, + state_respawn_timer, + rng_key + ): + # 1. Parse actions into logical directions + left = jnp.logical_or(jnp.logical_or(action == Action.LEFT, action == Action.LEFTFIRE), + jnp.logical_or(jnp.logical_or(action == Action.UPLEFT, action == Action.UPLEFTFIRE), + jnp.logical_or(action == Action.DOWNLEFT, action == Action.DOWNLEFTFIRE))) + right = jnp.logical_or(jnp.logical_or(action == Action.RIGHT, action == Action.RIGHTFIRE), + jnp.logical_or(jnp.logical_or(action == Action.UPRIGHT, action == Action.UPRIGHTFIRE), + jnp.logical_or(action == Action.DOWNRIGHT, action == Action.DOWNRIGHTFIRE))) + up = jnp.logical_or(jnp.logical_or(action == Action.UP, action == Action.UPFIRE), + jnp.logical_or(jnp.logical_or(action == Action.UPLEFT, action == Action.UPLEFTFIRE), + jnp.logical_or(action == Action.UPRIGHT, action == Action.UPRIGHTFIRE))) + down = jnp.logical_or(jnp.logical_or(action == Action.DOWN, action == Action.DOWNFIRE), + jnp.logical_or(jnp.logical_or(action == Action.DOWNLEFT, action == Action.DOWNLEFTFIRE), + jnp.logical_or(action == Action.DOWNRIGHT, action == Action.DOWNRIGHTFIRE))) + + any_direction = jnp.logical_or(jnp.logical_or(up, down), jnp.logical_or(right, left)) + + # 2. Determine new rotation (instant) + # UP=0, UPLEFT=2, LEFT=4, DOWNLEFT=6, DOWN=8, DOWNRIGHT=10, RIGHT=12, UPRIGHT=14 + new_rotation = jax.lax.cond( + up, + lambda: jax.lax.cond(left, lambda: 2, lambda: jax.lax.cond(right, lambda: 14, lambda: 0)), + lambda: jax.lax.cond( + down, + lambda: jax.lax.cond(left, lambda: 6, lambda: jax.lax.cond(right, lambda: 10, lambda: 8)), + lambda: jax.lax.cond( + left, lambda: 4, lambda: jax.lax.cond(right, lambda: 12, lambda: state_player_rotation) + ) + ) + ) + + player_rotation = jax.lax.cond( + any_direction, + lambda: new_rotation, + lambda: state_player_rotation + ) + + # 3. Apply physics based on the new rotation + player_x = state_player_x + player_y = state_player_y + player_speed_x = state_player_speed_x + player_speed_y = state_player_speed_y + + decel_x = self._env.decel_func(player_speed_x) + decel_y = self._env.decel_func(player_speed_y) + + accel_x = self._env.consts.ACCEL_PER_ROTATION[player_rotation][0] + accel_y = self._env.consts.ACCEL_PER_ROTATION[player_rotation][1] + + # In instant turn mod, pressing any direction triggers thrust + is_thrusting = any_direction + + adj_speed_x = jnp.logical_and( + jnp.logical_and(is_thrusting, jnp.abs(player_speed_x + accel_x) < self._env.consts.MAX_PLAYER_SPEED), + jnp.logical_not(player_rotation%8 == 0)) + adj_speed_y = jnp.logical_and( + jnp.logical_and(is_thrusting, jnp.abs(player_speed_y + accel_y) < self._env.consts.MAX_PLAYER_SPEED), + jnp.logical_not((player_rotation-4)%8 == 0)) + + # calculate new player speed + player_speed_x = jax.lax.cond( + adj_speed_x, + lambda: player_speed_x + accel_x, + lambda: player_speed_x + ) + player_speed_x = jax.lax.cond( + jnp.logical_and(jnp.logical_not(adj_speed_x), jnp.abs(player_speed_x) > jnp.abs(decel_x)), + lambda: player_speed_x + decel_x, + lambda: player_speed_x + ) + player_speed_x = jax.lax.cond( + jnp.logical_and(jnp.logical_not(adj_speed_x), jnp.abs(player_speed_x) <= jnp.abs(decel_x)), + lambda: 0, + lambda: player_speed_x + ) + + player_speed_y = jax.lax.cond( + adj_speed_y, + lambda: player_speed_y + accel_y, + lambda: player_speed_y + ) + player_speed_y = jax.lax.cond( + jnp.logical_and(jnp.logical_not(adj_speed_y), jnp.abs(player_speed_y) > jnp.abs(decel_y)), + lambda: player_speed_y + decel_y, + lambda: player_speed_y + ) + player_speed_y = jax.lax.cond( + jnp.logical_and(jnp.logical_not(adj_speed_y), jnp.abs(player_speed_y) <= jnp.abs(decel_y)), + lambda: 0, + lambda: player_speed_y + ) + + displace_x = self._env.speed_func(player_speed_x) + displace_y = self._env.speed_func(player_speed_y) + + player_x = jnp.int32(self._env.final_pos(self._env.consts.MIN_PLAYER_X, self._env.consts.MAX_PLAYER_X, player_x + displace_x)) + player_y = jnp.int32(self._env.final_pos(self._env.consts.MIN_PLAYER_Y, self._env.consts.MAX_PLAYER_Y, player_y + displace_y)) + + # We remove hyperspace (down) entirely so you can fly down without teleporting + + return jax.lax.cond( + state_respawn_timer <= 0, + lambda: (player_x, player_y, player_speed_x, player_speed_y, + player_rotation, state_respawn_timer, rng_key), + lambda: (state_player_x, state_player_y, state_player_speed_x, state_player_speed_y, + state_player_rotation, state_respawn_timer, rng_key) + ) diff --git a/src/jaxatari/games/mods/asteroids_mods.py b/src/jaxatari/games/mods/asteroids_mods.py new file mode 100644 index 000000000..f269b8359 --- /dev/null +++ b/src/jaxatari/games/mods/asteroids_mods.py @@ -0,0 +1,30 @@ +import os +from jaxatari.modification import JaxAtariModController +from jaxatari.games.mods.asteroids.asteroids_mod_plugins import DontShootMod, MatrixMod, SlowAsteroidsMod, InstantTurnMod + +class AsteroidsEnvMod(JaxAtariModController): + """ + Game-specific Mod Controller for Asteroids. + """ + + REGISTRY = { + "dont_shoot": DontShootMod, + "matrix_theme": MatrixMod, + "slow_asteroids": SlowAsteroidsMod, + "instant_turn": InstantTurnMod, + } + + _mod_sprite_dir = os.path.join(os.path.dirname(__file__), "asteroids", "sprites") + + def __init__(self, + env, + mods_config: list = [], + allow_conflicts: bool = False + ): + + super().__init__( + env=env, + mods_config=mods_config, + allow_conflicts=allow_conflicts, + registry=self.REGISTRY + ) From 880af7c549ee6c6594b4e2f76f9d4ffdb7357f24 Mon Sep 17 00:00:00 2001 From: Quentin Delfosse Date: Sat, 16 May 2026 10:39:26 +0200 Subject: [PATCH 28/28] Time Pilot mod precision --- src/jaxatari/games/mods/timepilot/timepilot_mod_plugins.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/jaxatari/games/mods/timepilot/timepilot_mod_plugins.py b/src/jaxatari/games/mods/timepilot/timepilot_mod_plugins.py index 7781da177..387b791cd 100644 --- a/src/jaxatari/games/mods/timepilot/timepilot_mod_plugins.py +++ b/src/jaxatari/games/mods/timepilot/timepilot_mod_plugins.py @@ -106,7 +106,8 @@ class ReverseChronologyMod(JaxAtariInternalModPlugin): } class InstantTurnMod(JaxAtariInternalModPlugin): - """Directly places the plane in the direction given by the action instead of progressively turning it, including diagonals.""" + """Directly places the plane in the direction given by the action instead of progressively turning it, including diagonals. + Also moves the ship.""" name = "instant_turn" # Override the game's ACTION_SET to include diagonal actions