From 0c4db6a321aabdfa5f2c9c9548cbdd7c5daa71e4 Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Thu, 6 Nov 2025 17:09:00 +0100 Subject: [PATCH 01/76] initial commit using pong as template --- src/jaxatari/games/UpNDown.py | 663 ++++++++++++++++++++++++++++++++++ 1 file changed, 663 insertions(+) create mode 100644 src/jaxatari/games/UpNDown.py diff --git a/src/jaxatari/games/UpNDown.py b/src/jaxatari/games/UpNDown.py new file mode 100644 index 000000000..904a45bb3 --- /dev/null +++ b/src/jaxatari/games/UpNDown.py @@ -0,0 +1,663 @@ +from jax._src.pjit import JitWrapped +import os +from functools import partial +from typing import NamedTuple, Tuple +import jax.lax +import jax.numpy as jnp +import chex + +import jaxatari.spaces as spaces +from jaxatari.renderers import JAXGameRenderer +from jaxatari.rendering import jax_rendering_utils as render_utils +from jaxatari.environment import JaxEnvironment, JAXAtariAction as Action + +class UpNDownConstants(NamedTuple): + FRAME_SKIP: int = 4 + DIFFICULTIES: chex.Array = jnp.array([0, 1, 2, 3, 4, 5]) + ACTION_REPEAT_PROBS: float = 0.25 + + +# immutable state container +class UpNDownState(NamedTuple): + player_y: chex.Array + player_speed: chex.Array + score: chex.Array + difficulty: chex.Array + + + +class EntityPosition(NamedTuple): + x: jnp.ndarray + y: jnp.ndarray + width: jnp.ndarray + height: jnp.ndarray + + +class EnemyCar(NamedTuple): + position: EntityPosition + speed: chex.Array + type: chex.Array + + +class UpNDownObservation(NamedTuple): + player: EntityPosition + enemies: jnp.ndarray + score: jnp.ndarray + +class Collectible(NamedTuple): + position: EntityPosition + type: chex.Array + value: chex.Array + + +class UpNDownInfo(NamedTuple): + time: jnp.ndarray + + +class JaxUpNDown(JaxEnvironment[UpNDownState, UpNDownObservation, UpNDownInfo, UpNDownConstants]): + def __init__(self, consts: UpNDownConstants = None, reward_funcs: list[callable]=None): + consts = consts or UpNDownConstants() + super().__init__(consts) + self.renderer = UpNDownRenderer(self.consts) + if reward_funcs is not None: + reward_funcs = tuple(reward_funcs) + self.reward_funcs = reward_funcs + self.action_set = [ + Action.NOOP, + Action.FIRE, + Action.RIGHT, + Action.LEFT, + Action.RIGHTFIRE, + Action.LEFTFIRE, + ] + self.obs_size = 3*4+1+1 + + def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: + up = jnp.logical_or(action == Action.LEFT, action == Action.LEFTFIRE) + down = jnp.logical_or(action == Action.RIGHT, action == Action.RIGHTFIRE) + + acceleration = self.consts.PLAYER_ACCELERATION[state.acceleration_counter] + + touches_wall = jnp.logical_or( + state.player_y < self.consts.WALL_TOP_Y, + state.player_y + self.consts.PLAYER_SIZE[1] > self.consts.WALL_BOTTOM_Y, + ) + + player_speed = state.player_speed + + player_speed = jax.lax.cond( + jnp.logical_or(jnp.logical_not(jnp.logical_or(up, down)), touches_wall), + lambda s: jnp.round(s / 2).astype(jnp.int32), + lambda s: s, + operand=player_speed, + ) + + direction_change_up = jnp.logical_and(up, state.player_speed > 0) + player_speed = jax.lax.cond( + direction_change_up, + lambda s: 0, + lambda s: s, + operand=player_speed, + ) + direction_change_down = jnp.logical_and(down, state.player_speed < 0) + + player_speed = jax.lax.cond( + direction_change_down, + lambda s: 0, + lambda s: s, + operand=player_speed, + ) + + direction_change = jnp.logical_or(direction_change_up, direction_change_down) + acceleration_counter = jax.lax.cond( + direction_change, + lambda _: 0, + lambda s: s, + operand=state.acceleration_counter, + ) + + player_speed = jax.lax.cond( + up, + lambda s: jnp.maximum(s - acceleration, -self.consts.MAX_SPEED), + lambda s: s, + operand=player_speed, + ) + + player_speed = jax.lax.cond( + down, + lambda s: jnp.minimum(s + acceleration, self.consts.MAX_SPEED), + lambda s: s, + operand=player_speed, + ) + + new_acceleration_counter = jax.lax.cond( + jnp.logical_or(up, down), + lambda s: jnp.minimum(s + 1, 15), + lambda s: 0, + operand=acceleration_counter, + ) + + proposed_player_y = jnp.clip( + state.player_y + player_speed, + self.consts.WALL_TOP_Y + self.consts.WALL_TOP_HEIGHT - 10, + self.consts.WALL_BOTTOM_Y - 4, + ) + + # Match original timing/buffering behavior + new_player_y, new_player_speed, new_acc_counter = jax.lax.cond( + state.step_counter % 2 == 0, + lambda _: (proposed_player_y, player_speed, new_acceleration_counter), + lambda _: (state.player_y, state.player_speed, state.acceleration_counter), + operand=None, + ) + + buffer = jax.lax.cond( + jax.lax.eq(state.buffer, state.player_y), + lambda _: new_player_y, + lambda _: state.buffer, + operand=None, + ) + final_player_y = state.buffer + + return UpNDownState( + player_y=final_player_y, + player_speed=new_player_speed, + ball_x=state.ball_x, + ball_y=state.ball_y, + enemy_y=state.enemy_y, + enemy_speed=state.enemy_speed, + ball_vel_x=state.ball_vel_x, + ball_vel_y=state.ball_vel_y, + player_score=state.player_score, + enemy_score=state.enemy_score, + step_counter=state.step_counter, + acceleration_counter=new_acc_counter, + buffer=buffer, + ) + + def _ball_step(self, state: UpNDownState, action) -> UpNDownState: + ball_x = state.ball_x + state.ball_vel_x + ball_y = state.ball_y + state.ball_vel_y + + wall_bounce = jnp.logical_or( + ball_y <= self.consts.WALL_TOP_Y + self.consts.WALL_TOP_HEIGHT - self.consts.BALL_SIZE[1], + ball_y >= self.consts.WALL_BOTTOM_Y, + ) + ball_vel_y = jnp.where(wall_bounce, -state.ball_vel_y, state.ball_vel_y) + + player_paddle_hit = jnp.logical_and( + jnp.logical_and(self.consts.PLAYER_X <= ball_x, ball_x <= self.consts.PLAYER_X + self.consts.PLAYER_SIZE[0]), + state.ball_vel_x > 0, + ) + + player_paddle_hit = jnp.logical_and( + player_paddle_hit, + jnp.logical_and( + state.player_y - self.consts.BALL_SIZE[1] <= ball_y, + ball_y <= state.player_y + self.consts.PLAYER_SIZE[1] + self.consts.BALL_SIZE[1], + ), + ) + + enemy_paddle_hit = jnp.logical_and( + jnp.logical_and(self.consts.ENEMY_X <= ball_x, ball_x <= self.consts.ENEMY_X + self.consts.ENEMY_SIZE[0] - 1), + state.ball_vel_x < 0, + ) + + enemy_paddle_hit = jnp.logical_and( + enemy_paddle_hit, + jnp.logical_and( + state.enemy_y - self.consts.BALL_SIZE[1] <= ball_y, + ball_y <= state.enemy_y + self.consts.ENEMY_SIZE[1] + self.consts.BALL_SIZE[1], + ), + ) + + paddle_hit = jnp.logical_or(player_paddle_hit, enemy_paddle_hit) + + section_height = self.consts.PLAYER_SIZE[1] / 5 + + hit_position = jnp.where( + paddle_hit, + jnp.where( + player_paddle_hit, + jnp.where( + ball_y < state.player_y + section_height, + -2.0, + jnp.where( + ball_y < state.player_y + 2 * section_height, + -1.0, + jnp.where( + ball_y < state.player_y + 3 * section_height, + 0.0, + jnp.where( + ball_y < state.player_y + 4 * section_height, + 1.0, + 2.0, + ), + ), + ), + ), + jnp.where( + ball_y < state.enemy_y + section_height, + -2.0, + jnp.where( + ball_y < state.enemy_y + 2 * section_height, + -1.0, + jnp.where( + ball_y < state.enemy_y + 3 * section_height, + 0.0, + jnp.where( + ball_y < state.enemy_y + 4 * section_height, + 1.0, + 2.0, + ), + ), + ), + ), + ), + 0.0, + ) + + paddle_speed = jnp.where( + player_paddle_hit, + state.player_speed, + jnp.where( + enemy_paddle_hit, + state.enemy_speed, + 0.0, + ), + ) + + ball_vel_y = jnp.where(paddle_hit, hit_position, ball_vel_y) + + boost_triggered = jnp.logical_and( + player_paddle_hit, + jnp.logical_or( + jnp.logical_or(action == Action.LEFTFIRE, action == Action.RIGHTFIRE), + action == Action.FIRE, + ), + ) + player_max_hit = jnp.logical_and(player_paddle_hit, state.player_speed == self.consts.MAX_SPEED) + ball_vel_x = jnp.where( + jnp.logical_or(boost_triggered, player_max_hit), + state.ball_vel_x + + jnp.sign(state.ball_vel_x), + state.ball_vel_x, + ) + + ball_vel_x = jnp.where( + paddle_hit, + -ball_vel_x, + ball_vel_x, + ) + + return UpNDownState( + player_y=state.player_y, + player_speed=state.player_speed, + ball_x=ball_x.astype(jnp.int32), + ball_y=ball_y.astype(jnp.int32), + enemy_y=state.enemy_y, + enemy_speed=state.enemy_speed, + ball_vel_x=ball_vel_x.astype(jnp.int32), + ball_vel_y=ball_vel_y.astype(jnp.int32), + player_score=state.player_score, + enemy_score=state.enemy_score, + step_counter=state.step_counter, + acceleration_counter=state.acceleration_counter, + buffer=state.buffer, + ) + + def _enemy_step(self, state: UpNDownState) -> UpNDownState: + should_move = state.step_counter % 8 != 0 + + direction = jnp.sign(state.ball_y - state.enemy_y) + + new_y = state.enemy_y + (direction * self.consts.ENEMY_STEP_SIZE).astype(jnp.int32) + enemy_y = jax.lax.cond( + should_move, lambda _: new_y, lambda _: state.enemy_y, operand=None + ) + return UpNDownState( + player_y=state.player_y, + player_speed=state.player_speed, + ball_x=state.ball_x, + ball_y=state.ball_y, + enemy_y=enemy_y.astype(jnp.int32), + enemy_speed=state.enemy_speed, + ball_vel_x=state.ball_vel_x, + ball_vel_y=state.ball_vel_y, + player_score=state.player_score, + enemy_score=state.enemy_score, + step_counter=state.step_counter, + acceleration_counter=state.acceleration_counter, + buffer=state.buffer, + ) + + def _score_and_reset(self, state: UpNDownState) -> UpNDownState: + player_goal = state.ball_x < 4 + enemy_goal = state.ball_x > 156 + ball_reset = jnp.logical_or(enemy_goal, player_goal) + + player_score = jax.lax.cond( + player_goal, + lambda s: s + 1, + lambda s: s, + operand=state.player_score, + ) + enemy_score = jax.lax.cond( + enemy_goal, + lambda s: s + 1, + lambda s: s, + operand=state.enemy_score, + ) + + current_values = ( + state.ball_x.astype(jnp.int32), + state.ball_y.astype(jnp.int32), + state.ball_vel_x.astype(jnp.int32), + state.ball_vel_y.astype(jnp.int32), + ) + ball_x_final, ball_y_final, ball_vel_x_final, ball_vel_y_final = jax.lax.cond( + ball_reset, + lambda x: self._reset_ball_after_goal((state, enemy_goal)), + lambda x: x, + operand=current_values, + ) + + step_counter = jax.lax.cond( + ball_reset, + lambda s: jnp.array(0), + lambda s: s + 1, + operand=state.step_counter, + ) + + enemy_y_final = jax.lax.cond( + ball_reset, + lambda s: self.consts.BALL_START_Y.astype(jnp.int32), + lambda s: state.enemy_y.astype(jnp.int32), + operand=None, + ) + + ball_x_final = jax.lax.cond( + step_counter < 60, + lambda s: self.consts.BALL_START_X.astype(jnp.int32), + lambda s: s, + operand=ball_x_final, + ) + ball_y_final = jax.lax.cond( + step_counter < 60, + lambda s: self.consts.BALL_START_Y.astype(jnp.int32), + lambda s: s, + operand=ball_y_final, + ) + + return UpNDownState( + player_y=state.player_y, + player_speed=state.player_speed, + ball_x=ball_x_final, + ball_y=ball_y_final, + enemy_y=enemy_y_final, + enemy_speed=state.enemy_speed, + ball_vel_x=ball_vel_x_final, + ball_vel_y=ball_vel_y_final, + player_score=player_score, + enemy_score=enemy_score, + step_counter=step_counter, + acceleration_counter=state.acceleration_counter, + buffer=state.buffer, + ) + + def _reset_ball_after_goal(self, state_and_goal: Tuple[UpNDownState, bool]) -> Tuple[chex.Array, chex.Array, chex.Array, chex.Array]: + state, scored_right = state_and_goal + + ball_vel_y = jnp.where( + state.ball_y > self.consts.BALL_START_Y, + 1, + -1, + ).astype(jnp.int32) + + ball_vel_x = jnp.where( + scored_right, 1, -1 + ).astype(jnp.int32) + + return ( + self.consts.BALL_START_X.astype(jnp.int32), + self.consts.BALL_START_Y.astype(jnp.int32), + ball_vel_x.astype(jnp.int32), + ball_vel_y.astype(jnp.int32), + ) + + def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: + state = UpNDownState( + player_y=jnp.array(96).astype(jnp.int32), + player_speed=jnp.array(0.0).astype(jnp.int32), + ball_x=jnp.array(78).astype(jnp.int32), + ball_y=jnp.array(115).astype(jnp.int32), + enemy_y=jnp.array(115).astype(jnp.int32), + enemy_speed=jnp.array(0.0).astype(jnp.int32), + ball_vel_x=self.consts.BALL_SPEED[0].astype(jnp.int32), + ball_vel_y=self.consts.BALL_SPEED[1].astype(jnp.int32), + player_score=jnp.array(0).astype(jnp.int32), + enemy_score=jnp.array(0).astype(jnp.int32), + step_counter=jnp.array(0).astype(jnp.int32), + acceleration_counter=jnp.array(0).astype(jnp.int32), + buffer=jnp.array(96).astype(jnp.int32), + ) + initial_obs = self._get_observation(state) + + return initial_obs, state + + @partial(jax.jit, static_argnums=(0,)) + def step(self, state: UpNDownState, action: chex.Array) -> Tuple[UpNDownObservation, UpNDownState, float, bool, UpNDownInfo]: + previous_state = state + state = self._player_step(state, action) + state = self._enemy_step(state) + state = self._ball_step(state, action) + state = self._score_and_reset(state) + + done = self._get_done(state) + env_reward = self._get_reward(previous_state, state) + info = self._get_info(state) + observation = self._get_observation(state) + + return observation, state, env_reward, done, info + + + def render(self, state: UpNDownState) -> jnp.ndarray: + return self.renderer.render(state) + + def _get_observation(self, state: UpNDownState): + player = EntityPosition( + x=jnp.array(self.consts.PLAYER_X), + y=state.player_y, + width=jnp.array(self.consts.PLAYER_SIZE[0]), + height=jnp.array(self.consts.PLAYER_SIZE[1]), + ) + + enemy = EntityPosition( + x=jnp.array(self.consts.ENEMY_X), + y=state.enemy_y, + width=jnp.array(self.consts.ENEMY_SIZE[0]), + height=jnp.array(self.consts.ENEMY_SIZE[1]), + ) + + ball = EntityPosition( + x=state.ball_x, + y=state.ball_y, + width=jnp.array(self.consts.BALL_SIZE[0]), + height=jnp.array(self.consts.BALL_SIZE[1]), + ) + return UpNDownObservation( + player=player, + enemy=enemy, + ball=ball, + score_player=state.player_score, + score_enemy=state.enemy_score, + ) + + @partial(jax.jit, static_argnums=(0,)) + def obs_to_flat_array(self, obs: UpNDownObservation) -> jnp.ndarray: + return jnp.concatenate([ + obs.player.x.flatten(), + obs.player.y.flatten(), + obs.player.height.flatten(), + obs.player.width.flatten(), + obs.enemy.x.flatten(), + obs.enemy.y.flatten(), + obs.enemy.height.flatten(), + obs.enemy.width.flatten(), + obs.ball.x.flatten(), + obs.ball.y.flatten(), + obs.ball.height.flatten(), + obs.ball.width.flatten(), + obs.score_player.flatten(), + obs.score_enemy.flatten() + ] + ) + + def action_space(self) -> spaces.Discrete: + return spaces.Discrete(6) + + def observation_space(self) -> spaces: + return spaces.Dict({ + "player": spaces.Dict({ + "x": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), + "y": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), + "width": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), + "height": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), + }), + "enemy": spaces.Dict({ + "x": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), + "y": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), + "width": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), + "height": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), + }), + "ball": spaces.Dict({ + "x": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), + "y": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), + "width": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), + "height": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), + }), + "score_player": spaces.Box(low=0, high=21, shape=(), dtype=jnp.int32), + "score_enemy": spaces.Box(low=0, high=21, shape=(), dtype=jnp.int32), + }) + + def image_space(self) -> spaces.Box: + return spaces.Box( + low=0, + high=255, + shape=(210, 160, 3), + dtype=jnp.uint8 + ) + + @partial(jax.jit, static_argnums=(0,)) + def _get_info(self, state: UpNDownState, ) -> UpNDownInfo: + return UpNDownInfo(time=state.step_counter) + + @partial(jax.jit, static_argnums=(0,)) + def _get_reward(self, previous_state: UpNDownState, state: UpNDownState): + return (state.player_score - state.enemy_score) - ( + previous_state.player_score - previous_state.enemy_score + ) + + @partial(jax.jit, static_argnums=(0,)) + def _get_done(self, state: UpNDownState) -> bool: + return jnp.logical_or( + jnp.greater_equal(state.player_score, 21), + jnp.greater_equal(state.enemy_score, 21), + ) + +class UpNDownRenderer(JAXGameRenderer): + def __init__(self, consts: UpNDownConstants = None): + super().__init__() + self.consts = consts or UpNDownConstants() + self.config = render_utils.RendererConfig( + game_dimensions=(210, 160), + channels=3, + #downscale=(84, 84) + ) + self.jr = render_utils.JaxRenderingUtils(self.config) + # 1. Create procedural assets for both walls + wall_sprite_top = self._create_wall_sprite(self.consts.WALL_TOP_HEIGHT) + wall_sprite_bottom = self._create_wall_sprite(self.consts.WALL_BOTTOM_HEIGHT) + + # 2. Update asset config to include both walls + asset_config = self._get_asset_config(wall_sprite_top, wall_sprite_bottom) + sprite_path = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/UpNDown" + + # 3. Make a single call to the setup function + ( + self.PALETTE, + self.SHAPE_MASKS, + self.BACKGROUND, + self.COLOR_TO_ID, + self.FLIP_OFFSETS + ) = self.jr.load_and_setup_assets(asset_config, sprite_path) + + def _create_wall_sprite(self, height: int) -> jnp.ndarray: + """Procedurally creates an RGBA sprite for a wall of given height.""" + wall_color_rgba = (*self.consts.SCORE_COLOR, 255) # e.g., (236, 236, 236, 255) + wall_shape = (height, self.consts.WIDTH, 4) + wall_sprite = jnp.tile(jnp.array(wall_color_rgba, dtype=jnp.uint8), (*wall_shape[:2], 1)) + return wall_sprite + + def _get_asset_config(self, wall_sprite_top: jnp.ndarray, wall_sprite_bottom: jnp.ndarray) -> list: + """Returns the declarative manifest of all assets for the game, including both wall sprites.""" + return [ + {'name': 'background', 'type': 'background', 'file': 'background.npy'}, + {'name': 'player', 'type': 'single', 'file': 'player.npy'}, + {'name': 'enemy', 'type': 'single', 'file': 'enemy.npy'}, + {'name': 'ball', 'type': 'single', 'file': 'ball.npy'}, + {'name': 'player_digits', 'type': 'digits', 'pattern': 'player_score_{}.npy'}, + {'name': 'enemy_digits', 'type': 'digits', 'pattern': 'enemy_score_{}.npy'}, + # Add the procedurally created sprites to the manifest + {'name': 'wall_top', 'type': 'procedural', 'data': wall_sprite_top}, + {'name': 'wall_bottom', 'type': 'procedural', 'data': wall_sprite_bottom}, + ] + + @partial(jax.jit, static_argnums=(0,)) + def render(self, state): + raster = self.jr.create_object_raster(self.BACKGROUND) + + player_mask = self.SHAPE_MASKS["player"] + raster = self.jr.render_at(raster, self.consts.PLAYER_X, state.player_y, player_mask) + + enemy_mask = self.SHAPE_MASKS["enemy"] + raster = self.jr.render_at(raster, self.consts.ENEMY_X, state.enemy_y, enemy_mask) + + ball_mask = self.SHAPE_MASKS["ball"] + raster = self.jr.render_at(raster, state.ball_x, state.ball_y, ball_mask) + + # --- Stamp Walls and Score (using the same color/ID) --- + score_color_tuple = self.consts.SCORE_COLOR # (236, 236, 236) + score_id = self.COLOR_TO_ID[score_color_tuple] + + # Draw walls (using separate sprites for top and bottom) + raster = self.jr.render_at(raster, 0, self.consts.WALL_TOP_Y, self.SHAPE_MASKS["wall_top"]) + raster = self.jr.render_at(raster, 0, self.consts.WALL_BOTTOM_Y, self.SHAPE_MASKS["wall_bottom"]) + + # Stamp Score using the label utility + player_digits = self.jr.int_to_digits(state.player_score, max_digits=2) + enemy_digits = self.jr.int_to_digits(state.enemy_score, max_digits=2) + + # Note: The logic for single/double digits is complex for a jitted function. + player_digit_masks = self.SHAPE_MASKS["player_digits"] # Assumes single color + enemy_digit_masks = self.SHAPE_MASKS["enemy_digits"] # Assumes single color + + is_player_single_digit = state.player_score < 10 + player_start_index = jax.lax.select(is_player_single_digit, 1, 0) + player_num_to_render = jax.lax.select(is_player_single_digit, 1, 2) + player_render_x = jax.lax.select(is_player_single_digit, + 120 + 16 // 2, + 120) + + raster = self.jr.render_label_selective(raster, player_render_x, 3, player_digits, player_digit_masks, player_start_index, player_num_to_render, spacing=16) + + is_enemy_single_digit = state.enemy_score < 10 + enemy_start_index = jax.lax.select(is_enemy_single_digit, 1, 0) + enemy_num_to_render = jax.lax.select(is_enemy_single_digit, 1, 2) + enemy_render_x = jax.lax.select(is_enemy_single_digit, + 10 + 16 // 2, + 10) + + raster = self.jr.render_label_selective(raster, enemy_render_x, 3, enemy_digits, enemy_digit_masks, enemy_start_index, enemy_num_to_render, spacing=16) + + return self.jr.render_from_palette(raster, self.PALETTE) \ No newline at end of file From 16b723bc39431b1317c3d53745ca2503b99ee5eb Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Fri, 7 Nov 2025 17:52:12 +0100 Subject: [PATCH 02/76] rough design of potential car movements --- src/jaxatari/games/UpNDown.py | 402 +++++++++++++--------------------- 1 file changed, 157 insertions(+), 245 deletions(-) diff --git a/src/jaxatari/games/UpNDown.py b/src/jaxatari/games/UpNDown.py index 904a45bb3..af62fe461 100644 --- a/src/jaxatari/games/UpNDown.py +++ b/src/jaxatari/games/UpNDown.py @@ -15,28 +15,44 @@ class UpNDownConstants(NamedTuple): FRAME_SKIP: int = 4 DIFFICULTIES: chex.Array = jnp.array([0, 1, 2, 3, 4, 5]) ACTION_REPEAT_PROBS: float = 0.25 - - -# immutable state container -class UpNDownState(NamedTuple): - player_y: chex.Array - player_speed: chex.Array - score: chex.Array - difficulty: chex.Array + MAX_SPEED: int = 4 + JUMP_FRAMES: int = 10 + LANDING_ZONE: int = 15 + FIRST_ROAD_LENGTH: int = 4 + SECOND_ROAD_LENGTH: int = 4 + FIRST_TRACK_CORNERS_X: chex.Array = jnp.array([20, 50, 80, 100]) #get actual values + FIRST_TRACK_CORNERS_Y: chex.Array = jnp.array([20, 50, 80, 100]) #get actual values + SECOND_TRACK_CORNERS_X: chex.Array = jnp.array([20, 50, 80, 100]) #get actual values + SECOND_TRACK_CORNERS_Y: chex.Array = jnp.array([20, 50, 80, 100]) #get actual values +# immutable state container class EntityPosition(NamedTuple): x: jnp.ndarray y: jnp.ndarray width: jnp.ndarray height: jnp.ndarray - -class EnemyCar(NamedTuple): +class Car(NamedTuple): position: EntityPosition speed: chex.Array type: chex.Array + current_road: chex.Array + road_index_A: chex.Array + road_index_B: chex.Array + direction_x: chex.Array + +class UpNDownState(NamedTuple): + score: chex.Array + difficulty: chex.Array + road_index: chex.Array + jump_cooldown: chex.Array + is_jumping: chex.Array + is_on_road: chex.Array + player_car: Car + + class UpNDownObservation(NamedTuple): @@ -65,270 +81,186 @@ def __init__(self, consts: UpNDownConstants = None, reward_funcs: list[callable] self.action_set = [ Action.NOOP, Action.FIRE, - Action.RIGHT, - Action.LEFT, - Action.RIGHTFIRE, - Action.LEFTFIRE, + Action.UPFIRE, + Action.UP, + Action.DOWN, + Action.DOWNFIRE, ] self.obs_size = 3*4+1+1 - def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: - up = jnp.logical_or(action == Action.LEFT, action == Action.LEFTFIRE) - down = jnp.logical_or(action == Action.RIGHT, action == Action.RIGHTFIRE) - - acceleration = self.consts.PLAYER_ACCELERATION[state.acceleration_counter] - - touches_wall = jnp.logical_or( - state.player_y < self.consts.WALL_TOP_Y, - state.player_y + self.consts.PLAYER_SIZE[1] > self.consts.WALL_BOTTOM_Y, + @partial(jax.jit, static_argnums=(0,)) + def _car_past_corner(self, car: Car, state: UpNDownState) -> chex.Array: + direction_change_A = jnp.logical_or(jnp.logical_and(car.speed > 0, car.position.y + car.speed > self.consts.FIRST_TRACK_CORNERS_Y[car.road_index+1]), jnp.logical_and(car.speed < 0, car.position.y + car.speed < self.consts.FIRST_TRACK_CORNERS_Y[car.road_index])) + direction_change_B = jnp.logical_or(jnp.logical_and(car.speed > 0, car.position.y + car.speed > self.consts.SECOND_TRACK_CORNERS_Y[car.road_index+1]), jnp.logical_and(car.speed < 0, car.position.y + car.speed < self.consts.SECOND_TRACK_CORNERS_Y[car.road_index])), + + road_index_A = jax.lax.cond(jnp.logical_and(direction_change_A, car.speed > 0), + lambda s: s + 1, + lambda s: s, + operand=car.road_index_A, ) - - player_speed = state.player_speed - - player_speed = jax.lax.cond( - jnp.logical_or(jnp.logical_not(jnp.logical_or(up, down)), touches_wall), - lambda s: jnp.round(s / 2).astype(jnp.int32), + road_index_A = jax.lax.cond(jnp.logical_and(direction_change_A, car.speed < 0), + lambda s: s - 1, lambda s: s, - operand=player_speed, + operand=car.road_index_A, ) - direction_change_up = jnp.logical_and(up, state.player_speed > 0) - player_speed = jax.lax.cond( - direction_change_up, - lambda s: 0, + road_index_B = jax.lax.cond(jnp.logical_and(direction_change_B, car.speed > 0), + lambda s: s + 1, lambda s: s, - operand=player_speed, + operand=car.road_index_B, ) - direction_change_down = jnp.logical_and(down, state.player_speed < 0) - - player_speed = jax.lax.cond( - direction_change_down, - lambda s: 0, + road_index_B = jax.lax.cond(jnp.logical_and(direction_change_B, car.speed < 0), + lambda s: s - 1, lambda s: s, - operand=player_speed, + operand=car.road_index_B, ) + current_road_length_A = self.consts.FIRST_ROAD_LENGTH + current_road_length_B = self.consts.SECOND_ROAD_LENGTH - direction_change = jnp.logical_or(direction_change_up, direction_change_down) - acceleration_counter = jax.lax.cond( - direction_change, - lambda _: 0, + road_index_A = jax.lax.cond(road_index_A < 0, + lambda s: current_road_length_A - 1, lambda s: s, - operand=state.acceleration_counter, + operand=road_index_A, ) - player_speed = jax.lax.cond( - up, - lambda s: jnp.maximum(s - acceleration, -self.consts.MAX_SPEED), + road_index_A = jax.lax.cond(road_index_A >= current_road_length_A, + lambda s: 0, lambda s: s, - operand=player_speed, + operand=road_index_A, ) - player_speed = jax.lax.cond( - down, - lambda s: jnp.minimum(s + acceleration, self.consts.MAX_SPEED), + road_index_B = jax.lax.cond(road_index_B < 0, + lambda s: current_road_length_B - 1, lambda s: s, - operand=player_speed, + operand=road_index_B, ) - new_acceleration_counter = jax.lax.cond( - jnp.logical_or(up, down), - lambda s: jnp.minimum(s + 1, 15), + road_index_B = jax.lax.cond(road_index_B >= current_road_length_B, lambda s: 0, - operand=acceleration_counter, + lambda s: s, + operand=road_index_B, ) - proposed_player_y = jnp.clip( - state.player_y + player_speed, - self.consts.WALL_TOP_Y + self.consts.WALL_TOP_HEIGHT - 10, - self.consts.WALL_BOTTOM_Y - 4, - ) + return road_index_A, road_index_B + + @partial(jax.jit, static_argnums=(0,)) + def _landing_in_water(self, state: UpNDownState, new_position_x: chex.Array, new_position_y: chex.Array) -> chex.Array: + road_A_x = ((new_position_y - self.consts.FIRST_TRACK_CORNERS_Y[state.player_car.road_index_A]) / (self.consts.FIRST_TRACK_CORNERS_Y[state.player_car.road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_Y[state.player_car.road_index_A])) * (self.consts.FIRST_TRACK_CORNERS_X[state.player_car.road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_X[state.player_car.road_index_A]) + self.consts.FIRST_TRACK_CORNERS_X[state.player_car.road_index_A] + road_B_x = ((new_position_y - self.consts.SECOND_TRACK_CORNERS_Y[state.player_car.road_index_B]) / (self.consts.SECOND_TRACK_CORNERS_Y[state.player_car.road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_Y[state.player_car.road_index_B])) * (self.consts.SECOND_TRACK_CORNERS_X[state.player_car.road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[state.player_car.road_index_B]) + self.consts.SECOND_TRACK_CORNERS_X[state.player_car.road_index_B] + distance_to_road_A = jnp.abs(new_position_x - road_A_x) + distance_to_road_B = jnp.abs(new_position_x - road_B_x) + landing_in_Water = jnp.logical_and(distance_to_road_A > self.consts.LANDING_ZONE, distance_to_road_B > self.consts.LANDING_ZONE) + between_roads = jnp.logical_and(new_position_x > jnp.minimum(road_A_x, road_B_x), new_position_x < jnp.maximum(road_A_x, road_B_x)) + return landing_in_Water, between_roads - # Match original timing/buffering behavior - new_player_y, new_player_speed, new_acc_counter = jax.lax.cond( - state.step_counter % 2 == 0, - lambda _: (proposed_player_y, player_speed, new_acceleration_counter), - lambda _: (state.player_y, state.player_speed, state.acceleration_counter), - operand=None, - ) + def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: + up = jnp.logical_or(action == Action.UP, action == Action.UPFIRE) + down = jnp.logical_or(action == Action.DOWN, action == Action.DOWNFIRE) + jump = jnp.logical_or(action == Action.FIRE, action == Action.UPFIRE, action == Action.DOWNFIRE) - buffer = jax.lax.cond( - jax.lax.eq(state.buffer, state.player_y), - lambda _: new_player_y, - lambda _: state.buffer, - operand=None, - ) - final_player_y = state.buffer - return UpNDownState( - player_y=final_player_y, - player_speed=new_player_speed, - ball_x=state.ball_x, - ball_y=state.ball_y, - enemy_y=state.enemy_y, - enemy_speed=state.enemy_speed, - ball_vel_x=state.ball_vel_x, - ball_vel_y=state.ball_vel_y, - player_score=state.player_score, - enemy_score=state.enemy_score, - step_counter=state.step_counter, - acceleration_counter=new_acc_counter, - buffer=buffer, - ) - def _ball_step(self, state: UpNDownState, action) -> UpNDownState: - ball_x = state.ball_x + state.ball_vel_x - ball_y = state.ball_y + state.ball_vel_y + player_speed = state.player_car.speed - wall_bounce = jnp.logical_or( - ball_y <= self.consts.WALL_TOP_Y + self.consts.WALL_TOP_HEIGHT - self.consts.BALL_SIZE[1], - ball_y >= self.consts.WALL_BOTTOM_Y, + player_speed = jax.lax.cond( + jnp.logical_and(state.player_car.speed < self.consts.MAX_SPEED, up), + lambda s: s + 1, + lambda s: s, + operand=player_speed, ) - ball_vel_y = jnp.where(wall_bounce, -state.ball_vel_y, state.ball_vel_y) - player_paddle_hit = jnp.logical_and( - jnp.logical_and(self.consts.PLAYER_X <= ball_x, ball_x <= self.consts.PLAYER_X + self.consts.PLAYER_SIZE[0]), - state.ball_vel_x > 0, + player_speed = jax.lax.cond( + jnp.logical_and(state.player_car.speed > -self.consts.MAX_SPEED, down), + lambda s: s - 1, + lambda s: s, + operand=player_speed, ) - player_paddle_hit = jnp.logical_and( - player_paddle_hit, - jnp.logical_and( - state.player_y - self.consts.BALL_SIZE[1] <= ball_y, - ball_y <= state.player_y + self.consts.PLAYER_SIZE[1] + self.consts.BALL_SIZE[1], - ), - ) - enemy_paddle_hit = jnp.logical_and( - jnp.logical_and(self.consts.ENEMY_X <= ball_x, ball_x <= self.consts.ENEMY_X + self.consts.ENEMY_SIZE[0] - 1), - state.ball_vel_x < 0, + is_jumping = jnp.logical_or(jnp.logical_and(state.is_jumping, state.jump_cooldown > 0), jnp.logical_and(is_on_road, jnp.logical_and(player_speed > 0, state.jump_cooldown == 0))) + jump_cooldown = jax.lax.cond( + state.jump_cooldown > 0, + lambda s: s - 1, + lambda s: jnp.cond(jnp.logical_and(is_jumping), + lambda _: state.JUMP_FRAMES, + lambda _: 0, + operand=None), + operand=state.jump_cooldown, ) - enemy_paddle_hit = jnp.logical_and( - enemy_paddle_hit, - jnp.logical_and( - state.enemy_y - self.consts.BALL_SIZE[1] <= ball_y, - ball_y <= state.enemy_y + self.consts.ENEMY_SIZE[1] + self.consts.BALL_SIZE[1], - ), - ) - paddle_hit = jnp.logical_or(player_paddle_hit, enemy_paddle_hit) - - section_height = self.consts.PLAYER_SIZE[1] / 5 - - hit_position = jnp.where( - paddle_hit, - jnp.where( - player_paddle_hit, - jnp.where( - ball_y < state.player_y + section_height, - -2.0, - jnp.where( - ball_y < state.player_y + 2 * section_height, - -1.0, - jnp.where( - ball_y < state.player_y + 3 * section_height, - 0.0, - jnp.where( - ball_y < state.player_y + 4 * section_height, - 1.0, - 2.0, - ), - ), - ), - ), - jnp.where( - ball_y < state.enemy_y + section_height, - -2.0, - jnp.where( - ball_y < state.enemy_y + 2 * section_height, - -1.0, - jnp.where( - ball_y < state.enemy_y + 3 * section_height, - 0.0, - jnp.where( - ball_y < state.enemy_y + 4 * section_height, - 1.0, - 2.0, - ), - ), - ), - ), - ), - 0.0, - ) - paddle_speed = jnp.where( - player_paddle_hit, - state.player_speed, - jnp.where( - enemy_paddle_hit, - state.enemy_speed, - 0.0, - ), - ) - ball_vel_y = jnp.where(paddle_hit, hit_position, ball_vel_y) + ##check if player is on the the road + is_on_road = ~state.is_jumping - boost_triggered = jnp.logical_and( - player_paddle_hit, - jnp.logical_or( - jnp.logical_or(action == Action.LEFTFIRE, action == Action.RIGHTFIRE), - action == Action.FIRE, - ), - ) - player_max_hit = jnp.logical_and(player_paddle_hit, state.player_speed == self.consts.MAX_SPEED) - ball_vel_x = jnp.where( - jnp.logical_or(boost_triggered, player_max_hit), - state.ball_vel_x - + jnp.sign(state.ball_vel_x), - state.ball_vel_x, - ) + road_index_A, road_index_B = self._car_past_corner(state.player_car, state) - ball_vel_x = jnp.where( - paddle_hit, - -ball_vel_x, - ball_vel_x, + direction_change = jax.lax.cond( + jnp.logical_and(is_on_road, jnp.logical_or(jnp.logical_and(jnp.equal(road_index_A, state.player_car.road_index_A)) , state.player_car.current_road == 0), (jnp.logical_and(jnp.equal(road_index_B, state.player_car.road_index_B)) , state.player_car.current_road == 1) ), + lambda s: False, + lambda s: True, + operand=None, ) - return UpNDownState( - player_y=state.player_y, - player_speed=state.player_speed, - ball_x=ball_x.astype(jnp.int32), - ball_y=ball_y.astype(jnp.int32), - enemy_y=state.enemy_y, - enemy_speed=state.enemy_speed, - ball_vel_x=ball_vel_x.astype(jnp.int32), - ball_vel_y=ball_vel_y.astype(jnp.int32), - player_score=state.player_score, - enemy_score=state.enemy_score, - step_counter=state.step_counter, - acceleration_counter=state.acceleration_counter, - buffer=state.buffer, + + car_direction_x = jax.lax.cond( + direction_change, + lambda s: jax.lax.cond(state.player_car.current_road == 0, + lambda s: self.consts.FIRST_TRACK_CORNERS_X[road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_X[road_index_A], + lambda s: self.consts.SECOND_TRACK_CORNERS_X[road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[road_index_B], + operand=None), + lambda s: s, + operand=state.player_car.direction_x, ) + + is_landing = jnp.logical_and(state.jump_cooldown == 1, jump_cooldown == 0) - def _enemy_step(self, state: UpNDownState) -> UpNDownState: - should_move = state.step_counter % 8 != 0 + ##calculate new position with speed (TODO: calculate better speed) + player_y = state.player_car.position.y + player_speed + player_x = state.player_car.position.x + player_speed * car_direction_x - direction = jnp.sign(state.ball_y - state.enemy_y) + landing_in_Water, between_roads, road_A_x, road_B_x = self._landing_in_water(state, player_x, player_y) + landing_in_Water = jnp.logical_and(is_landing, landing_in_Water) + - new_y = state.enemy_y + (direction * self.consts.ENEMY_STEP_SIZE).astype(jnp.int32) - enemy_y = jax.lax.cond( - should_move, lambda _: new_y, lambda _: state.enemy_y, operand=None + current_road = jax.lax.cond( + landing_in_Water, + lambda s: 2, + lambda s: jax.lax.cond( + is_on_road, + lambda s: state.player_car.current_road, + lambda s: jax.lax.cond( + jnp.abs(player_x - road_A_x) < jnp.abs(player_x - road_B_x), + lambda s: 0, + lambda s: 1, + operand=None, + ), + operand=None, + ), + operand=None, ) return UpNDownState( - player_y=state.player_y, - player_speed=state.player_speed, - ball_x=state.ball_x, - ball_y=state.ball_y, - enemy_y=enemy_y.astype(jnp.int32), - enemy_speed=state.enemy_speed, - ball_vel_x=state.ball_vel_x, - ball_vel_y=state.ball_vel_y, - player_score=state.player_score, - enemy_score=state.enemy_score, - step_counter=state.step_counter, - acceleration_counter=state.acceleration_counter, - buffer=state.buffer, + score=state.score, + difficulty=state.difficulty, + road_index=state.road_index, + jump_cooldown=jump_cooldown, + is_jumping=is_jumping, + is_on_road=is_on_road, + player_car=Car( + position=EntityPosition( + x=player_x, + y=player_y, + width=state.player_car.position.width, + height=state.player_car.position.height, + ), + speed=player_speed, + direction_x=car_direction_x, + current_road=current_road, + road_index_A=road_index_A, + road_index_B=road_index_B, + type=state.player_car.type, + ), ) def _score_and_reset(self, state: UpNDownState) -> UpNDownState: @@ -405,26 +337,6 @@ def _score_and_reset(self, state: UpNDownState) -> UpNDownState: buffer=state.buffer, ) - def _reset_ball_after_goal(self, state_and_goal: Tuple[UpNDownState, bool]) -> Tuple[chex.Array, chex.Array, chex.Array, chex.Array]: - state, scored_right = state_and_goal - - ball_vel_y = jnp.where( - state.ball_y > self.consts.BALL_START_Y, - 1, - -1, - ).astype(jnp.int32) - - ball_vel_x = jnp.where( - scored_right, 1, -1 - ).astype(jnp.int32) - - return ( - self.consts.BALL_START_X.astype(jnp.int32), - self.consts.BALL_START_Y.astype(jnp.int32), - ball_vel_x.astype(jnp.int32), - ball_vel_y.astype(jnp.int32), - ) - def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: state = UpNDownState( player_y=jnp.array(96).astype(jnp.int32), From e657459e41a9b19727a2160bcc76a33760e1f3bf Mon Sep 17 00:00:00 2001 From: shaik05 Date: Thu, 13 Nov 2025 12:07:26 +0100 Subject: [PATCH 03/76] added basic interface template --- src/jaxatari/games/upndown_interface.py | 53 +++++++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 src/jaxatari/games/upndown_interface.py diff --git a/src/jaxatari/games/upndown_interface.py b/src/jaxatari/games/upndown_interface.py new file mode 100644 index 000000000..68f8c76fe --- /dev/null +++ b/src/jaxatari/games/upndown_interface.py @@ -0,0 +1,53 @@ +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +from jaxatari.environment import JAXAtariAction as Action +from upndown import JaxUpNDown, UpNDownConstants # <-- your game file + +def visualize_frame(frame: jnp.ndarray): + """Render an RGB frame using matplotlib.""" + plt.imshow(frame.astype(jnp.uint8)) + plt.axis("off") + plt.show(block=False) + plt.pause(0.05) + plt.clf() + + +def main(): + # Initialize environment + env = JaxUpNDown(UpNDownConstants()) + + # Reset environment + obs, state = env.reset() + print("Initial observation:", obs) + + # Display initial render + frame = env.render(state) + visualize_frame(frame) + + # Create a random key for sampling actions + key = jax.random.PRNGKey(0) + + # Run for 50 steps + for step in range(50): + key, subkey = jax.random.split(key) + # Choose a random action from action space + action = jax.random.choice(subkey, jnp.arange(len(env.action_set))) + + obs, state, reward, done, info = env.step(state, action) + + # Render and display + frame = env.render(state) + visualize_frame(frame) + + print(f"Step {step}: action={env.action_set[int(action)]}, reward={reward}, done={done}") + + if bool(done): + print("Game over — resetting environment.") + obs, state = env.reset() + + plt.close() + +if __name__ == "__main__": + main() From 8df941a2b4de43dee5407435779dcf33b2e6de8f Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Thu, 13 Nov 2025 20:36:52 +0100 Subject: [PATCH 04/76] add partial backrounds and car sprites --- .../sprites/up_n_down/backround/backround1.npy | Bin 0 -> 102944 bytes .../up_n_down/backround/backround10.npy | Bin 0 -> 25548 bytes .../up_n_down/backround/backround11.npy | Bin 0 -> 37180 bytes .../up_n_down/backround/backround12.npy | Bin 0 -> 43808 bytes .../up_n_down/backround/backround13.npy | Bin 0 -> 45096 bytes .../sprites/up_n_down/backround/backround2.npy | Bin 0 -> 37948 bytes .../sprites/up_n_down/backround/backround3.npy | Bin 0 -> 37328 bytes .../sprites/up_n_down/backround/backround4.npy | Bin 0 -> 46944 bytes .../sprites/up_n_down/backround/backround5.npy | Bin 0 -> 34848 bytes .../sprites/up_n_down/backround/backround6.npy | Bin 0 -> 34624 bytes .../sprites/up_n_down/backround/backround7.npy | Bin 0 -> 47560 bytes .../sprites/up_n_down/backround/backround8.npy | Bin 0 -> 41132 bytes .../sprites/up_n_down/backround/backround9.npy | Bin 0 -> 31748 bytes .../games/sprites/up_n_down/player_car.npy | Bin 0 -> 1152 bytes 14 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/jaxatari/games/sprites/up_n_down/backround/backround1.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/backround/backround10.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/backround/backround11.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/backround/backround12.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/backround/backround13.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/backround/backround2.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/backround/backround3.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/backround/backround4.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/backround/backround5.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/backround/backround6.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/backround/backround7.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/backround/backround8.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/backround/backround9.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/player_car.npy diff --git a/src/jaxatari/games/sprites/up_n_down/backround/backround1.npy b/src/jaxatari/games/sprites/up_n_down/backround/backround1.npy new file mode 100644 index 0000000000000000000000000000000000000000..6c353b610ae66a21a8991791f5875675ae42bf48 GIT binary patch literal 102944 zcmeI4O^cjG6otpFn{2b3-6%{JK_lW1aHoji!c_<+VnB>c)Qum1!XG@Qk={H_IDO_; z)m!!U5yEih+`4t|t*7hl^o;rU>+in$=KBw>{Q zbNAcbi(j6BXh z;V>4)S~E5qQ|#{{&r0T)y+_u&H$D6O!`S-o$HJKY+1cLimC^l#u`t$}x!IXwzb&4X z%(>bOSpV#LwfTmz_1=$#u_xuF(M(TC-P`}mZaIDwW)jCG+ zceuw?&7*5__hY>3XKSDJy1H8b=>5W&{#jV{Ou61YW5zR|tC>$}H$UUUSQraq{`)b` zjIXQFn)SH4cfG8q^?v8vW8AZMO=>^xuQc6Q z*8XQnWX9E7vCwPZFc!vov31P5B4fwek6+z;g7L>Z5Bph3DYE9Mk zswX!;<5hm;GoPzfe)`q?)%C1L=`i-F{C6Jr%owlEh(6aV-FiLqDILbbn9HbQivAuL zV^3;k%4fc#K7OvI)+_&*=b>L!e)`;pzS3Op>H}kHjjDk@^U+tD>s_t#D_`xW)+=A- zSH8+mpZVx3&GoJhW2e8@p=PMgMENSe@>PEN%tv2ou6K193u9sI^yiiC>qyP8x9`^L z#>(zee++)l!drzuf*56yRp4OW6yZNi@)qY{@^jXQ7SG78Ws?Yk|eAV@8zc3cY z!dU(N7++_cX={D_JS|`U|9ykQJ$7lcXw4M!UD`ZVzhSKXvr^6c()=(M#(Fa!#@1tu z&t|2`kcy|z^{zgLG481}Yf$m@xxTgLeyw%udYR9yxAl7FQ#y?09#hOJhLj)1!dMuq z|34jShP{1$K3COlz1*+5p81qM++!-HT2pnss!#du{;lIxeXj4;$Nbgx%%^l1YyYg| z8LBgIedbr1`Bgl9u6K193u9!8+GUFIdt{6=ay4f}?dD@V^Hp^tW-1OnYdczcYW?ZdcE2&jD@i(qnE%KXXI+m zNUi7R+A3c8%(pjvxW`uSNu8nc*?;SJHy`U+U5ByLe~Z9nLd7e8b%s>^VJwU(Mymdb zjH!D;`J5T!=~JtC`mW}Ft@Y@7&S7kHO=_lYJ&b35<-76B=W6cPT92+Lj6J^R%+;By zc>1nZ`&a$Z^@Oo7rkJZTy#ixu2FmA*7*GFwnxCT`HS-<6fBZ_{)yF)ozK-7iz?iCi zbWNx-%CIAb2YV& z{V<-stC?@KJ}}mMPj@DaXMM_N{OnrwL*La~_aDaE|18NpQ!>N#^j*z-qjeYyW4)PM zfib>rd(UL+{Os4&te@-Y_trfaQ$5jlb?=#MouBhyJ@j48e7$uT+x}U}GbS_2XS~wP z*IS3NFc!xA-yiEev(fo^FL8hRu4ca0y2m}n8rie*8Lu?+?M<_P_DA2<%-33nvH6~r zJYzDWe8wxye64jD3uCRBV!k~v#@8u(+?!VY(sy<1d5vIy znXjtD*!<5*o-vtGKI4^UzN!vmVT{aGnd16AGS+)$oZ;45^-JH?)&90%j5V`oDBpb@t`^V!SwDSOGoR8k z++(bfJu9E_N;BWyH2Y_N^j*z-N{6x2ds5A8cAw0!Uiz+PKBdE07z<;kzdy#;jXNX8 zGvD5{>X*K&)x2iB$5=CaRzBmEX1={?_Rs$4yPEl^!&v>_PUlQ#(`1JA(RVfTQHQZG z<}$^2GDW=tW2{;ANZ-}mZ*Q9OV1M*o&3s`@|DA#}TV0bG_Cw#*%ooPO7@2UHVm$MO zu`os^$c(F*Z$-w~181Oo#w*Qyd(-Tn{n2-I7(2bkI77~SZ(8+B-_>C(jFB0>hUmLG zjD;~W6UO|XVRMGA=FBfmyYpiF9_}%9=Iq(kD*xVncV3JSV|iAp*9e($br=g{VNAdO za$i?&ygGx`eRm#i{2uNxw`Vtgb%s>^?mXQ1FqUVf%ZwWz#=@A(j2piKV{T7wyn2?b z?z{7F2Dn_b)7z<-nhSd5nrk{n?8LAog_Er5=U)8Vj&+vXs z)mZhbGg9lf_N((z>s5bheHeRkuV$oXytS|DullM#l|PJyF~vw2%lj?Gi1Mopsr4i7 zG1X($S3RoyYW>!}s=w;1{#1UoK8)#mP|dKnugX^&xazC=RDQKy z`C%-Kg|YR2rmfDf>Z^N1Dvf%*iKx`0Vg)Cyk0uf{8zQ4zFys| z7T-QzudbHco1G%%Lo1RVm@o;mzT5V;@PbE-5l`qcrux6?DgXP-;MIR zzkj~zdMUN`$?ZDVHs;mv_Iq)!z8bIXd|tm^+&J5=;ZWn9%`V$D3URrP}TcXv-V?U&U(m$f~6iT7h>t%N`A8t1~8&Doy4#QR|F z|7Ta1r*;j_W$p6fkXW!5tflUqLwM@iGbYx$<}49QU6ZjKFX2&FjbJUFYj&@6aZ(}{ ztOaYqn)Y`9Tfa)&dtEi`nv8Yrn!QAyfR_RBhqdz$;H`(%yh&efcq z^%(aw_fvcOj~!<}xtjg59^;8jRwuOv zYr&eaL3UJaby90htZ@d_i`w(~|KW0V-+uJTTCO+T-?vvopO5qB>b~a*)^?wveS5E- NFPAOuSI?iD{|}RAGR6P^ literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/backround/backround11.npy b/src/jaxatari/games/sprites/up_n_down/backround/backround11.npy new file mode 100644 index 0000000000000000000000000000000000000000..06b3d1675bf65113926d862950b295f1d1d9aa46 GIT binary patch literal 37180 zcmeI*Pl}vX6o&C~>K=rem4?nB!63Q-XOT>TIFiP+Bp^n*H4~FTAY=pXAS-Z(W!FHd z{&m6eY2^Q-6Qzg=AO=a0TV+dg>o)!FvTFV41awjc5D z)6HhH_wn%E!~Kwt4aos zpIg}+vond--v!T(dyM9h&C9uu_-tL-d-B+?Q_r(bMeEDCa_*F8$3Kf4%I0fyE;|F) zAJ6;nJjz|abhh6t+rRbY-0k~ZG&A$Hv5sWfd01c0m2;6JIumiUUpZIKMUF&{nMeDr z!8v=*ax^3JwQ+V{u6G_$Ud~;A=gfB>(ac=Gbhh6t+rRbZ&?r#a!yCG ztW(yTFVQ)B%Ca-C-h6Frzgc!3)|<=aTz%)OGZE{M>&@kIuAI}Eh;_*I=5je#&go3V zI^=qD`7StT&)FGSzcjYrEZeX3<`FN!IeW~qGqB!#ZEU|;b{^K7M_kU`{;njN+0t?y z$+G=gZys?uSI+5>tV2=mdU=h`*<&>$~N9XJ@%g(@h^R==4X4!dI zZ=U6HZhe2dr6XC^A?wYvT+WqqkyBBA7oD@`-ZOIlxv~9b*?z1yZ{<1M=j<`d&cJ%} zwXywX*?CxR-pb|N@;_U3B+EKvy?HB_bLE^4m2>qsGLaKH5^=Pjd2gQLJ{Qd-%AJAh z&DX~En`P%=y?Jjg=a%=`sv}v}A?wY1b2(Sem2I}i6WADhd$=|5X_qLp>X{mjSaa;}^!=WhRv^W5I2y%W|v{Z_Rt3 ow{^d|&rN@y^EGUBEUNcd#;x;->ZhL9KZlt4N&o-= literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/backround/backround12.npy b/src/jaxatari/games/sprites/up_n_down/backround/backround12.npy new file mode 100644 index 0000000000000000000000000000000000000000..ea76f49ffd82c144453aa08bd26c71bdfbb90aaa GIT binary patch literal 43808 zcmeI4O=}cE5Qf*Q7q4D~y{$?DM*IMe9y|mT5kx$NpeqI>N}^u;073Exyn4zn@E?xZ z!$9Z3Y)wt~Oi!(^Fzmc}tE;Q3pPdZ}=GT+wkDtE0H2XaJvUt5bes#2XFkgKAw6oZn zFW$UA`gpW|xcC0(^>X#~NBi%NmuFu;e!Kr+d8V)5zqK{rxqEMGe&@#4{M)?8zss}P ztVyrK-w&Ib*TKP+#&fY&uBAUUZPiQP(o^e?^`xw=?&*)Mz4K)J*!7m)&+3Q1rAybB zo3A!+*1YJmU(09w)V0+QeM?WRU-4G=`v>wsek`By)LdU%i#PE$CWFy3o3 zl1IG7;%#?#x5>G>FfnAz*u3S;YPELASwzSkeSYFi&$lt(XW~t~iMNZtM;UWA>-M_4 zceSZ;<$j*ls5#w}&$rYY;y(JkhM0*bXY~1rw>r-$4`XSpc=MfC^m(OcU*kUd{PcYD zyPuqml}+vdx|&Tc}1UBdiFK$qt8#zH{biEWo+Hg z=7#igc^k5}F1g8Dmo+&t(dU)Zdx^Et=O^Cu`>Yc0gRVKmn|SNWN{Kk}Cf-VL)HR2A z6K`EvDG?{$#9IlDy5*aprjecnzeO{&K;msF) ze&Vg)?-#tBl!_N`zVnJcuTtFf%@cio(VlOdci(f{ly#@4H=CSm>^Y3h8~e+d)oLyC z{ck$j{KT8@`+=$PEZ%(Q7j0hR&3C?>SM=u_d1Ze?>2-e_Zb*;)axeN`&3vUbc`Kct z`PjeZGv3n7w`rRFvp@P?&3vV`c$@a$Q{{P+AIoRFrJ1j^7H{G$Ww10~%(rR0rL$Yz zv)|~N{ZX?&`d-a^xw!VNYjpur5=CibT+w9L%ix2ukcK-i2aXwSiJa5)l+GjpXvmR>kCf-seES@Ys&GoPQXuPop z_HX%&w>0zB)~uKHSiU!&`MjF@iMR95H_nsus;#a1ET8pT@$|i#`-wO4mNG;>Qm@ZF z^NBa{mU2WsQm@ZF^NF{9ylwwDXlh(MbJk$_-gxHoYVJqg=$F>4m-#K<8_#@R&HZX? z*3Wt@-y6?-Ud{c)+xq7l=T%$tyjhRsd*hkUtGSi`ng``Q+tz^6UH2WICQ)-ps$v z&#%vJ<`=W&`>)Tx+|KUazx{mvb#|wpym&qxADtXe$Hz~nwz8;Q7qs8*N{`ab= z>+0&!V%x>_Zj0T;W4@@~rN*`x*4q}l^cY{dHfJlI4?QQvh4NWH8qfPn)3tEMzVRHC z&v>Pok6PWYe3g%W?YdOXR9|_HjHe&f%tsx~$9R>Ge(l&deE^NTJP{ zIZMG6a}I^Gwm*MG^Re$k)8;JAJR^rgK=irAc=nn;bu^y$Q!`)RbYJEf z&n>EXuIm11d^8{P*RIXk>bc2tjB=skmCyQp)8@>aReSk{is&;YxB*SIm^MUIWuSGZ2kMQ=yQkh?4|MIl`x2;t@ z`bxLGzTElD+4^UVHhWKQn47-RyuUVW&dga3X3d#7GiPfCqR$n^=ji3O8frcCmEKmX zM6c9&rus~uy;VNrQ){&z`bwu>Z`=9J*?MNS-P_H4*;nN=zH6=4OJC`(_oM2UGS5^W>2qJoXMERMt(U&iUGGQLH)s3JTh&+PlN-j< zSDN>irp;MDI3tJVtodBoCu;Uh`HU}JtNqYdy0ravHP2L^=(BIiXME{e?T5b7rR~?8 zHP6@F=R|H;pT5$(ziVyI`pFqNG-s>-XUM)N-F4qm*H`do(Y~m7wXfQJ zwg0H!bwA8m8P82}5ame4GkIvO*vJ6l;8DU%~?0|DSGaTBjrcywH1gh6opMo!z;KmNT#qvC@3jV8W2LFph6ZgB2ttf8)*1>30mIOB*n`1hpe-g z-I?7PpGd3C@wt0v=H72-6f3`9y?y!m-Mz`T$@kgCYV%<|dp4c@__~o4o& z_4}Lk#p>?!7t5>7>h|-^r{(9>t$z4)F`q8Z&gRo6kLJ^#)BpInKbcIn&2|0f`L?F( z{Cu+Azqq!2FU9X~em&mqUmRoKJB8mH&hGBDtylGye&3q)WPbX!YqNH^_mPJ#n(Qza z{o1uzGi$jRubDO0ko!t*&H0>~Hf!njHL`qaY;nD2&8(4$RNAbWwNz{|uURu|WFnO| zYi2DKTg+?L%o>?UrOldIOT`xR*325ebBJoziaMH)@yu7dHfzWCHPw5v6V)nzZN6DE zYq^*%$nLKl+PM7o<6mTrytcEuXJtJc7KDA1=U*>Pd}>F_@((~ z&8+2O)~uN|v$kWPsYUb*VtkXQxhubkYkODu=tp(id5kr#q#NVs)t@5$2V*F&QR8#>?oh{N;BWk+N_zi zTnv`P8f(J$nto|ot%rV84_%L(tVL_Xc-Bh!j32vJ>!BajL$BYg4R?mJ_hd);j8~fZ za%;0@)^agd5^H>KbJb|5anyS0M|JLXrDQF-cE+<7%4dA(TCJCUROhabSsTw8%HES5 zYe)(>p(xMnXhl%Cu`BQF`jEvKI3cGsvh*CI(7Zb z+IY`U_MYr0pYcjFU*FoSnYD(&zS&|vvsMmk>|Mi1b60*7*Y?hQU3C}M+Rph{9OuqB z`qqc+leb+Z-KC;8H=tnj4b=79gtQ}*pD_eZttd+|edr4-L&v>Po zFRCM~sktbhb7MSx>)3ms>PbJUyY|DZ#roN`-YK6aI~xn>Dl6mboZfj5llb Ru*TkAUOw3Nc`<)f{{<4M{W1Um literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/backround/backround3.npy b/src/jaxatari/games/sprites/up_n_down/backround/backround3.npy new file mode 100644 index 0000000000000000000000000000000000000000..ee7c16619682eb58b3f417f6eb3b8f43549aaa7d GIT binary patch literal 37328 zcmeI5&1w`u6ouQZ8((2|qcDO2+5O~2>J8s?aSBi?ls?=AM^9&`on7ed^Z32?P$K3%|BkP zzOK$L-(RiHm)++t&OWb~*PpLHoqbte>xa*tEM`Z?$BWt1M~m67*)6{AH%-&_ugkwD z?a*~`@t}=)4Xq__crOm=Pt3jQ<~n4}eKCE`y8C@`+s|{#eVdcM?`JZn{Ss^Y_o)Vk z#oGV(a1HxeQVpl4O&fBB)apIgLh|g!hUBWT;b-ua`zO}I_phX9K&;)Z{hpFdu_o5W z;%`c}Sd;f1%JDhX!rY`!f_@I@_BLh@#ai8GYEEZ?SQBeQJnCGln~#wF9pdH?k85)b zeUJ8@Yu4KIz1owzVsAck-hMNlJ+*ws4{5P>bH9fAmAgkayqfiKJbjNAYhta>6WQr| z-93Ge7Heg*Mi#u9tZ+R25bb4+YgxI`_iC;uz58r<`^k9r)AAYb(PC|P?>xQBA(m`- zHEZB_`XO4ZiM2jIVy$c*`g<9Azu9cskc;bcF8VQb=()yRtla5)HP^dq+S@P2vtO3a z_z*4D$~rU2hF6mnj;9|}i#4&<=Pf2%oU1g}`hDr{Sx;@6^?Eh?!SVFR)UjD(u2yaI zy_)OQroH`OJo{n!jE||sTHfzavf$QBwkEz94Uhhz{;nie?pOemwHTK5pEq$-%da<=PAI38u%V&H}TC9!x z4ka62O;$Lbe$HB~iM3RGmd2X(T+{dJ(7y9;dA*wJb3FZ$bndLNmsaoTdo|bdYVUfC z=X#dU_?)#^+x3cQTqxR0ncxyiTC2O&^ zd#+h5aIYEf)z*CUOV(mdtfk^ttcf+TcDvuP@*T!L(oe0eeCXGtOF!4R7i7os89%nR z@}Xao7Hh-r(Aeye9nMd`CN0*)S}LAvV~u^Y&JFsp^_$<1+y2G*`(Ew$WUaP(t-Ylm zTeF{>pMK3+td08)B|BoRX}8R=H>GQ9Z|TR@>>uZ+U$fpaw*6jntzB*0>(sjT{tngm E53}cxUjP6A literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/backround/backround4.npy b/src/jaxatari/games/sprites/up_n_down/backround/backround4.npy new file mode 100644 index 0000000000000000000000000000000000000000..65d9d322f490b9d5430979b2e88fc2e795ab12b9 GIT binary patch literal 46944 zcmeI4Pm3H?6vf-E8^1#DT47uSqJm2ol8qvQ3s*9jkpv`>8Fk~IFX7hjI;I#N3_rR& z@63Hw(_M1{P36?N_rCk?`*n5bN%GIv-+lGX_wVg~-~DlTd3E#S_2KjV;m_Zm9WM5V zpWa^odVTrkhqu=+ucrH7UjBS@_3r-7tIJ=m-sulN`}AV}?33pg`{y5B?El(7;Me=R z-R^dJz4`Zxyb3dT{rZDj=QVoXT-&W*%~$I^l(TAI)qJ%t)_3(|*H`P)=lvR3J^VNFcESyb@R5=Z2$2sC$sny&V?`l=w^|}7q`D(qbKF_#s zp2OPu+rK})eaf%K+16*UHdkD4>+@Cn3uh1eTqQ?Jt8>`8AI`$r8jL>`&Nxf;{Hgj} zZ?ty%FEdvua<}e6U!0D_Cp=c!r8<`ID6Qi(Kt6(^ZnR#`~K0b z$M{y~sa)IhY-{`Fxw=}-!>z}7Rex)KI6MBism__)xLVcUnjg->*&0l5g){eg!g$V_ zKD8Uqd}`KHn)ynzo{Fc>d{_5!#&wluA1a&ps-RjI;tM#h+bA8R-Z=B8f?S8edYJTgZUEO+&Z*|&L`7N$pRXlxHw{_;O)q2(Z zxxVJ^hqJq%H*oHA^J@21JbhQUb>^_s@m1o>}}~4CldE z`mR>>=(}3=!+I*7el?%YVAjTJ^(vDxQ8dpY^E2SvZ@R2xrItoX)vW^ZlfJ#?R^YJ({|5J;u{_HS3+aZqKt* z>(iY(#UKTFqVKfRnS-IrUR@!Xg48Nakv{m^%{+yBz_*-toI=5v#C=D8}L@k+Cv(yAVP zRZschES!<4V_I>f{Bz}u`{HcqkJ5ABtGV~tFZ1cUn)Rqh*JFL^#u@KZkM7H@$9UBr z{ZrHKei*N`+ONu|uj(m3oIUK{yQ{e>U!5a;SG()0`l_Dt!&x{}j7`4cNcrc&8TY~2 z&_6e=_Cw#*Y96eo;$6R&GuCzcV!W%X>ofnTHh_Yjw!L+|@_)c0_z4)@D=`mScZrFDCrjb6vyH{;ol@)>{b zTJ4vzLd}St!p(8`mSDje&OsiKR0>qu?s%CI&}y#d_h)nWe{PhG2d(RcOe`G&L8__>pNuH=UK^j*z*RUOX4*)ayI zTroeKohN5JN3x=P#w*QwuI}Yb^`U(B#d!MB9{*~d^j%#&k8tLlXLa67<&zt(L*Lb` z=jw14&U%=0xng`c+XiPm*B*A97}j_7TAXoDZl8(w-VdC2RafU9&b;%h&U>x=Dp$;} z_8rc`8JRk!!`U`Dsh%tIU%c4eR$i><>fYyBrN=4P?JMzP-#?Fie&KBFd7kR~KJU5{Kb%H_*h zkzdJf8WjS48ml&HMAWA0Bl-x}UR))#lTB_Nt%#`hGH-_p{H}>znoR z>f`nLVzrxpz5Kda-Og{mEWfR8?Wc>E^ZsP9nD;N9&HLYd!k@=o*KOsw`ggkZJZER! z*7JDwX6&h}_Pf57uAy+Y8@+pZI2+=u-t)B{-+KMYn(r=ei9fOy|6OgVZ&S#nHTFFB zT5D&S0v|cg&{}`U+H}65$SGO7`=(-TXy@ggtdX@^@6_exldb3R=56_gdh*=GdUIvG z@7}B3KeFcQ-Znhph2l3o)7dT_tN|jjT`r zBdMc@o0Qm`tl|EtHIO*eqkaFqkUFJ~L)Vk6-S6Mf_4n+T%9zA?_D+cv&$Pt(`x&p0*!ZYxf&!WdZ%Q-lM-}YkJx` zStDy2W3sNXq^*ycHK~D&+F4s)mvf|!*7I5SvFf9LP1Z0M>)FvC^~Pr7d91eeyruOq zAF|fYH`LmD^vB+6HpXkVrl+lwHL|8LChHnY+WM4OlbXn=oxS$DRWBhw>8^(9`w(Bp3?ect==~jduz4HTE5;{Ymff9dy!iY^R?Qg^U1A0<$Dc# zm#cQ~{yg8NJeZHwPT8}in0?k-%Ii=r*5081Ar-M4 c&l+39kRR%y9g2&cM?J5N`&#w9@!Ed>0ih7rLI3~& literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/backround/backround6.npy b/src/jaxatari/games/sprites/up_n_down/backround/backround6.npy new file mode 100644 index 0000000000000000000000000000000000000000..0642e3cff1f1ba56d35e673c11980f4dfe92951d GIT binary patch literal 34624 zcmeI2zi!k(5XMbQ!z;Kmplc!}6qFPw4Ty$<3Z2LyB1H+h0~+8dXsPm|CN8Y3zHs)A zcV_?ji)7u-d3HSW&Cg>e%AdFI-@N@IePPy6m=+x_}}(rw!A^UeO-etZ4# zW`D68zJI;FI_&PgKYZDK-QC$w&t7cWle5!Jd-`nC{%#liJZ_q%ACK$5ullLu_V#%{ zb>QA9-{IVx|J-?uFZIX$b2i48+S2%;{+M5Ayt(~T=SRkdPs|JLsW`(qZGVgxYR~0A zcRuv`rSG59XX?C!vo-59Xjppvq4`g7Ge4B z=FF{g$=X-{TK&$?n?B@lb*Y)N5RHbtu1;%QYwAbNj^DFuWXr9YoK63oZ`4Xtr!nT5 zJ~j2LX`X57vL>F!S}pY>XZO#|QSY^Wk~2OhW3W%o$Qe2FW57Zu%lpqF-`rxx)UV|C zvzR*N#)w8fG4&&7>E^KpDJ-w?T=&M*ucUbvQNPlzLOV0dyE!8QP{-(-9su%ie=AH9A!@8xmrcNcs zlIn&2nl+O%zgeoOd8}B-jnvl6lbn&WWBie``~T%}HPY0j#8|EtOV*Y(=QPijXyfj& z)TxBnt`eQafec zF*%d!g8nIW%Hi6mhty74134?_{Y=Ug`ja!hkNGfw+(>PoYgqYko^hU8vE{}kXXMP~ z!D>y;$e9&eZd`Ik&Ria>*5s@l&X57=?*ZtK@hoji&af`2ZCQ5>v88&Ve@qR?8RuCH z-pCm_BWL9@5Tkc^4SX(2ZH%XBYtA!iUFfe_*BWD?o>JRVGjdkX%#?D4{^YD47A$*- z+(>Q9+&ItbWmw7?`fG9*!WrfzwKeNttcUlDv>)iNsVO;o*dI-OQ^rDWq_$@6_Db@iul~K1vVY`Ezn{6@BPlui2e0iZE&u=k literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/backround/backround7.npy b/src/jaxatari/games/sprites/up_n_down/backround/backround7.npy new file mode 100644 index 0000000000000000000000000000000000000000..259a34f810ef7a6bb7bc68472c654c527964e2f2 GIT binary patch literal 47560 zcmeI2&uSDw7{u4BPqDXEiHAT$@R)-)#fyk12`1t}5=qpHdhr#!iibRbZ@6X%i}A;V z+M4e9x0JA$n)>z?2AxWu2&Up{;B`p)9Z;_K@8dA8T?eqR>Z@GGRcK+%7 z==9Cm`SD3}{prz%i<7JC7w?Zgo?PjBk00+X_a7bXEf4PREx#{2{JFbWEY{8E;j2IS z^tcC?mk-uAKiuz|UYDPlH#NO2XIwkiLEqKP+cnMo;QHvhnt4;x;cS?5oZOHD`mSc) z)O0utXG1U<&cazZ8-{@)_F26L*@LQAKJ&)bUC%SFlk1@GYUa&dtMzkzs$Tib8(W97 zVa`mK3stZDvAN>>;Vhhyi9U2Vn+9h*pLnn7yPA1J>sp*~-CPHKS2J&Jx^%tnc-D`- zU*Rn0*$@ndvv3y9{{P+c&3C~NdsB0co4q$bGn+oOThH;-%ZXWY%>u?s%YBSZA7tX?28yB^&g|l#0o2jyYFpcsuBm?ME6ug3dgU{(wjPQzu9NGa?`q~vO>;f$ zkG|5S ztXGTtFj?@h%8eOJ3&u|AxIvu&6oH?9t6 z9%qNYf2_NFcrTejpStgQ&PUDpUA^H*=u>k(`qXYc$5S)U)!{6hk*PNAa>e?&aK^sd_QJWwyuLMi<9zg~ z-TK<&RUZ4Jo+@W*UG!beedt@WKh8&=+O4lWUgfbr>TtH3?@h%8eOHqsrNdb`a~W&u z6<74bnaA1bufuhhkGhx4WWDlPPhC1*Zx+S;AJw4VJlPx-7@nt9aW?E3%ga$jm|_db=@ zvw!9(pY=*Jk2;)%vxW)9k@CqEbvO%W;q2dcK|W7>Ug*1;d9}6bhx3)zt32ha`ROx{ zI_Fu=vxW(B<7#q69nQj8IJ^D5$>)vF4}DiNuT-n~nO|C8J5TjPpYzZ4JX7nU?`rN- zsaEqdzqGz~p6Z7_=MQJs?@hH&+&5PYd_DpkGyZLW*#;3=u@kD<-2*TcXc=mXALthm*MRCUF+_Hs^@)hwVPjB z&;CmD)Oh933tr1 Syy8gt;VhhO!?5B=`F{cQg9AcR-q*Nn{ogFNp)N1F=Df6|#s?7KlVPumQHY05|XnxS@%KMk-p< z)z#IH>Bk>wZyM|*RWB>Onn^w7>Vd?wXbefw7V_;0f^BhaIxn3Bv)i0gb zJcpKIq&m&{fid(k$#-Fn9qAkx17lz;&hHP>@0`s~@*GyIDRB+oJL>_jGg$9aK!nv z$QW0zWR26Lhc$^toR5!2we~Tte#x4TMlJeEG~#@GGy!9Aeyj4)yaZp4AIX|e4PXq6 zNtiNyT4c#LXPAEE&m>*HN%GUGc2WAnc` zQXNAyY9W?tp|L|WXtj@pXwpKgh9epiQ-e|&qxEX+m>SiXOQRW$iCHrkYx6utdC}Nl z4CgNw0~)Sqj4Qv-dMka*r{)%YDIXd;r6#2?rqPAQq|~T}+*AvVow7zS*2Z~E!xfDI zV>o}g7@)jpY*+q}^;-K_NX;$8YB-`XIW<^{Fd@t%f$fYMPs}2r(8!LE3=PjIHEBrneok?iZNQR#`dkFHGhp}G^Qjyz!>^iY97HD z7z1NvG7$T@(D$6ij(tBd^Tq088eM2i%$jP*!<4^f3)bG-gR=eD~HeM(fqszW3RhzeY0}QUz!(@y%_SJi!`R_) zehj%%bC5E>l(7(fPO&L#s3CXH`hhX@vDDvDFb2lJSRMva>x_5h;5;w&tWga9de#r` T$~=7%jG>Rk=v7>fo*DZKUKXOw literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/backround/backround9.npy b/src/jaxatari/games/sprites/up_n_down/backround/backround9.npy new file mode 100644 index 0000000000000000000000000000000000000000..ca78ccff5956da850b147227da97d14abebac977 GIT binary patch literal 31748 zcmeI*L2evX5XSLg%N>wrH-OB-65enCcA#t!VudVX!~zjy1Dhn1%~f4F=1^@rbX|Gaz0pWj?xZLe=`uC`x)akYKD{fJ+m zZZ@08>H6x&{~mi@Z{B?Nxa4K6daFn5emUCTTeqH8wqE_zBX++$!H zORu+{EL)HJ)w8UhTv>ZMpKE5vEU0H$Ke@72S(`GE&6evYpSgQwt+F;{!|Y^v@|nAj zlr?+E&SLcSZ{L3Txa76=Jl4$Gl5^U7tv#z#pEvsaDr<*#=&7?5onKiy{#lJat0)V- z_0i{7pKHheS?Qf6{moXSyFOz5dUMaL*{i*%XIZ~XW9MOi>RHyWH&@o`bFFvACjZJ> zWo^nr@B7GE^PSOqR%`p)pKLwqS=Mi9UNdXu5anaqQjxj@>u6PQO+@W!CJ`UevR!-)P+0 zZ+5@ym!4N;ZQOTglnwPPSJsaIe>F?bYHh#PY-IPle(8QM`E#xJ?3R8O`dPmmt&ep5 zBGxazC!qT4gQD*raE( z<@%AbX6@4qKK8EPxpCB2RImG8ubyT7B0i_AS*yKAy6Yp>@7%bxe)rp-dY1KzxUzOR z?@*t$*@<-5N337Om9@%Rl&LvA%9i>_S+i$*QO~k|r{<{sY<>EqE$>*FkYqj@C zcYVbAotl@fS3mnxk9eP*_t)qrSJn>y=H+Zo&F??Q5B)w{pC0a4kGQgS$WAmXb!UHS zwmvyqU$kF6;!Dn&vvn4c?)r%JTbp~=r@#HF%VteozmeXd)*khyo@LKlo1=PL>(O67 zb-A*3_#TF{$+ELwo1=PL>(O67b-A)uSu-}A{Jzqar zfAuJ9A6mcPv%mCiWzUk=_V@fOx7M58AJv!jbA8s|^_8{5_dK$*b(XSzz1j2S*7Nm~ z^;fT~Ro2YN9Ge~4ed?99L)M&4G&|SJ?(fZ>FSnkrU&NkQS*xs>i8+q4rLIq9t+Hk& z=GZJntWRaFvSudcILel~K2IrY_MT;DDL?g}rx4B4`9!?W=5$5gm!?OvQNOhQTl;Z7 n5l=bW-&@ZfZGSWy_0jt4U9a_+pCvZek%)KCkM literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/player_car.npy b/src/jaxatari/games/sprites/up_n_down/player_car.npy new file mode 100644 index 0000000000000000000000000000000000000000..54646d43afbe3bd8d6bd6b7344293ff74d6267ec GIT binary patch literal 1152 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1JQ);NLqoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$-W;zNWY@(^7P^&-=;9|f8|A9i-Fs4c*kpl+~{Ktop%tX?KVB*sYlEb4P zEJsK`Odg9l$YztmhB+Qx4@?{v4O0)3hw))FOdJ=DOC3y{EE*O*FdE%*m^h3E%9CXe cE`2a_(P@}^ba5De5Xb|a00y|iV^GTj02)|$&;S4c literal 0 HcmV?d00001 From 03be31fe311130786649e84df22c73039748212d Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Fri, 14 Nov 2025 23:02:04 +0100 Subject: [PATCH 05/76] create first running version doing nothing --- .../games/{UpNDown.py => jax_upndown.py} | 255 +++--------------- .../background1.npy} | Bin .../background10.npy} | Bin .../background11.npy} | Bin .../background12.npy} | Bin .../background13.npy} | Bin .../background2.npy} | Bin .../background3.npy} | Bin .../background4.npy} | Bin .../background5.npy} | Bin .../background6.npy} | Bin .../background7.npy} | Bin .../background8.npy} | Bin .../background9.npy} | Bin 14 files changed, 39 insertions(+), 216 deletions(-) rename src/jaxatari/games/{UpNDown.py => jax_upndown.py} (56%) rename src/jaxatari/games/sprites/up_n_down/{backround/backround1.npy => background/background1.npy} (100%) rename src/jaxatari/games/sprites/up_n_down/{backround/backround10.npy => background/background10.npy} (100%) rename src/jaxatari/games/sprites/up_n_down/{backround/backround11.npy => background/background11.npy} (100%) rename src/jaxatari/games/sprites/up_n_down/{backround/backround12.npy => background/background12.npy} (100%) rename src/jaxatari/games/sprites/up_n_down/{backround/backround13.npy => background/background13.npy} (100%) rename src/jaxatari/games/sprites/up_n_down/{backround/backround2.npy => background/background2.npy} (100%) rename src/jaxatari/games/sprites/up_n_down/{backround/backround3.npy => background/background3.npy} (100%) rename src/jaxatari/games/sprites/up_n_down/{backround/backround4.npy => background/background4.npy} (100%) rename src/jaxatari/games/sprites/up_n_down/{backround/backround5.npy => background/background5.npy} (100%) rename src/jaxatari/games/sprites/up_n_down/{backround/backround6.npy => background/background6.npy} (100%) rename src/jaxatari/games/sprites/up_n_down/{backround/backround7.npy => background/background7.npy} (100%) rename src/jaxatari/games/sprites/up_n_down/{backround/backround8.npy => background/background8.npy} (100%) rename src/jaxatari/games/sprites/up_n_down/{backround/backround9.npy => background/background9.npy} (100%) diff --git a/src/jaxatari/games/UpNDown.py b/src/jaxatari/games/jax_upndown.py similarity index 56% rename from src/jaxatari/games/UpNDown.py rename to src/jaxatari/games/jax_upndown.py index af62fe461..4d63a6455 100644 --- a/src/jaxatari/games/UpNDown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -24,6 +24,7 @@ class UpNDownConstants(NamedTuple): FIRST_TRACK_CORNERS_Y: chex.Array = jnp.array([20, 50, 80, 100]) #get actual values SECOND_TRACK_CORNERS_X: chex.Array = jnp.array([20, 50, 80, 100]) #get actual values SECOND_TRACK_CORNERS_Y: chex.Array = jnp.array([20, 50, 80, 100]) #get actual values + PLAYER_SIZE: Tuple[int, int] = (4, 16) @@ -46,7 +47,6 @@ class Car(NamedTuple): class UpNDownState(NamedTuple): score: chex.Array difficulty: chex.Array - road_index: chex.Array jump_cooldown: chex.Array is_jumping: chex.Array is_on_road: chex.Array @@ -57,8 +57,6 @@ class UpNDownState(NamedTuple): class UpNDownObservation(NamedTuple): player: EntityPosition - enemies: jnp.ndarray - score: jnp.ndarray class Collectible(NamedTuple): position: EntityPosition @@ -90,8 +88,8 @@ def __init__(self, consts: UpNDownConstants = None, reward_funcs: list[callable] @partial(jax.jit, static_argnums=(0,)) def _car_past_corner(self, car: Car, state: UpNDownState) -> chex.Array: - direction_change_A = jnp.logical_or(jnp.logical_and(car.speed > 0, car.position.y + car.speed > self.consts.FIRST_TRACK_CORNERS_Y[car.road_index+1]), jnp.logical_and(car.speed < 0, car.position.y + car.speed < self.consts.FIRST_TRACK_CORNERS_Y[car.road_index])) - direction_change_B = jnp.logical_or(jnp.logical_and(car.speed > 0, car.position.y + car.speed > self.consts.SECOND_TRACK_CORNERS_Y[car.road_index+1]), jnp.logical_and(car.speed < 0, car.position.y + car.speed < self.consts.SECOND_TRACK_CORNERS_Y[car.road_index])), + direction_change_A = jnp.logical_or(jnp.logical_and(car.speed > 0, car.position.y + car.speed > self.consts.FIRST_TRACK_CORNERS_Y[car.road_index_A+1]), jnp.logical_and(car.speed < 0, car.position.y + car.speed < self.consts.FIRST_TRACK_CORNERS_Y[car.road_index_A])) + direction_change_B = jnp.logical_or(jnp.logical_and(car.speed > 0, car.position.y + car.speed > self.consts.SECOND_TRACK_CORNERS_Y[car.road_index_B+1]), jnp.logical_and(car.speed < 0, car.position.y + car.speed < self.consts.SECOND_TRACK_CORNERS_Y[car.road_index_B])) road_index_A = jax.lax.cond(jnp.logical_and(direction_change_A, car.speed > 0), lambda s: s + 1, @@ -151,12 +149,12 @@ def _landing_in_water(self, state: UpNDownState, new_position_x: chex.Array, new distance_to_road_B = jnp.abs(new_position_x - road_B_x) landing_in_Water = jnp.logical_and(distance_to_road_A > self.consts.LANDING_ZONE, distance_to_road_B > self.consts.LANDING_ZONE) between_roads = jnp.logical_and(new_position_x > jnp.minimum(road_A_x, road_B_x), new_position_x < jnp.maximum(road_A_x, road_B_x)) - return landing_in_Water, between_roads + return landing_in_Water, between_roads, road_A_x, road_B_x def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: up = jnp.logical_or(action == Action.UP, action == Action.UPFIRE) down = jnp.logical_or(action == Action.DOWN, action == Action.DOWNFIRE) - jump = jnp.logical_or(action == Action.FIRE, action == Action.UPFIRE, action == Action.DOWNFIRE) + jump = jnp.logical_or(action == Action.FIRE, jnp.logical_or(action == Action.UPFIRE, action == Action.DOWNFIRE)) @@ -177,12 +175,12 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: ) - is_jumping = jnp.logical_or(jnp.logical_and(state.is_jumping, state.jump_cooldown > 0), jnp.logical_and(is_on_road, jnp.logical_and(player_speed > 0, state.jump_cooldown == 0))) + is_jumping = jnp.logical_or(jnp.logical_and(state.is_jumping, state.jump_cooldown > 0), jnp.logical_and(state.is_on_road, jnp.logical_and(player_speed > 0, state.jump_cooldown == 0))) jump_cooldown = jax.lax.cond( state.jump_cooldown > 0, lambda s: s - 1, - lambda s: jnp.cond(jnp.logical_and(is_jumping), - lambda _: state.JUMP_FRAMES, + lambda s: jax.lax.cond(is_jumping, + lambda _: self.consts.JUMP_FRAMES, lambda _: 0, operand=None), operand=state.jump_cooldown, @@ -197,7 +195,7 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: road_index_A, road_index_B = self._car_past_corner(state.player_car, state) direction_change = jax.lax.cond( - jnp.logical_and(is_on_road, jnp.logical_or(jnp.logical_and(jnp.equal(road_index_A, state.player_car.road_index_A)) , state.player_car.current_road == 0), (jnp.logical_and(jnp.equal(road_index_B, state.player_car.road_index_B)) , state.player_car.current_road == 1) ), + jnp.logical_and(is_on_road, jnp.logical_or(jnp.logical_and(jnp.equal(road_index_A, state.player_car.road_index_A) , state.player_car.current_road == 0), (jnp.logical_and(jnp.equal(road_index_B, state.player_car.road_index_B) , state.player_car.current_road == 1)))) , lambda s: False, lambda s: True, operand=None, @@ -243,7 +241,6 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: return UpNDownState( score=state.score, difficulty=state.difficulty, - road_index=state.road_index, jump_cooldown=jump_cooldown, is_jumping=is_jumping, is_on_road=is_on_road, @@ -263,95 +260,28 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: ), ) - def _score_and_reset(self, state: UpNDownState) -> UpNDownState: - player_goal = state.ball_x < 4 - enemy_goal = state.ball_x > 156 - ball_reset = jnp.logical_or(enemy_goal, player_goal) - - player_score = jax.lax.cond( - player_goal, - lambda s: s + 1, - lambda s: s, - operand=state.player_score, - ) - enemy_score = jax.lax.cond( - enemy_goal, - lambda s: s + 1, - lambda s: s, - operand=state.enemy_score, - ) - - current_values = ( - state.ball_x.astype(jnp.int32), - state.ball_y.astype(jnp.int32), - state.ball_vel_x.astype(jnp.int32), - state.ball_vel_y.astype(jnp.int32), - ) - ball_x_final, ball_y_final, ball_vel_x_final, ball_vel_y_final = jax.lax.cond( - ball_reset, - lambda x: self._reset_ball_after_goal((state, enemy_goal)), - lambda x: x, - operand=current_values, - ) - - step_counter = jax.lax.cond( - ball_reset, - lambda s: jnp.array(0), - lambda s: s + 1, - operand=state.step_counter, - ) - - enemy_y_final = jax.lax.cond( - ball_reset, - lambda s: self.consts.BALL_START_Y.astype(jnp.int32), - lambda s: state.enemy_y.astype(jnp.int32), - operand=None, - ) - - ball_x_final = jax.lax.cond( - step_counter < 60, - lambda s: self.consts.BALL_START_X.astype(jnp.int32), - lambda s: s, - operand=ball_x_final, - ) - ball_y_final = jax.lax.cond( - step_counter < 60, - lambda s: self.consts.BALL_START_Y.astype(jnp.int32), - lambda s: s, - operand=ball_y_final, - ) - - return UpNDownState( - player_y=state.player_y, - player_speed=state.player_speed, - ball_x=ball_x_final, - ball_y=ball_y_final, - enemy_y=enemy_y_final, - enemy_speed=state.enemy_speed, - ball_vel_x=ball_vel_x_final, - ball_vel_y=ball_vel_y_final, - player_score=player_score, - enemy_score=enemy_score, - step_counter=step_counter, - acceleration_counter=state.acceleration_counter, - buffer=state.buffer, - ) def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: state = UpNDownState( - player_y=jnp.array(96).astype(jnp.int32), - player_speed=jnp.array(0.0).astype(jnp.int32), - ball_x=jnp.array(78).astype(jnp.int32), - ball_y=jnp.array(115).astype(jnp.int32), - enemy_y=jnp.array(115).astype(jnp.int32), - enemy_speed=jnp.array(0.0).astype(jnp.int32), - ball_vel_x=self.consts.BALL_SPEED[0].astype(jnp.int32), - ball_vel_y=self.consts.BALL_SPEED[1].astype(jnp.int32), - player_score=jnp.array(0).astype(jnp.int32), - enemy_score=jnp.array(0).astype(jnp.int32), - step_counter=jnp.array(0).astype(jnp.int32), - acceleration_counter=jnp.array(0).astype(jnp.int32), - buffer=jnp.array(96).astype(jnp.int32), + score=0, + difficulty=self.consts.DIFFICULTIES[0], + jump_cooldown=0, + is_jumping=False, + is_on_road=True, + player_car=Car( + position=EntityPosition( + x=50, + y=50, + width=self.consts.PLAYER_SIZE[0], + height=self.consts.PLAYER_SIZE[1], + ), + speed=0, + direction_x=0, + current_road=0, + road_index_A=0, + road_index_B=0, + type=0, + ), ) initial_obs = self._get_observation(state) @@ -361,9 +291,6 @@ def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: def step(self, state: UpNDownState, action: chex.Array) -> Tuple[UpNDownObservation, UpNDownState, float, bool, UpNDownInfo]: previous_state = state state = self._player_step(state, action) - state = self._enemy_step(state) - state = self._ball_step(state, action) - state = self._score_and_reset(state) done = self._get_done(state) env_reward = self._get_reward(previous_state, state) @@ -378,31 +305,13 @@ def render(self, state: UpNDownState) -> jnp.ndarray: def _get_observation(self, state: UpNDownState): player = EntityPosition( - x=jnp.array(self.consts.PLAYER_X), - y=state.player_y, + x=jnp.array(state.player_car.position.x), + y=state.player_car.position.y, width=jnp.array(self.consts.PLAYER_SIZE[0]), height=jnp.array(self.consts.PLAYER_SIZE[1]), ) - - enemy = EntityPosition( - x=jnp.array(self.consts.ENEMY_X), - y=state.enemy_y, - width=jnp.array(self.consts.ENEMY_SIZE[0]), - height=jnp.array(self.consts.ENEMY_SIZE[1]), - ) - - ball = EntityPosition( - x=state.ball_x, - y=state.ball_y, - width=jnp.array(self.consts.BALL_SIZE[0]), - height=jnp.array(self.consts.BALL_SIZE[1]), - ) return UpNDownObservation( player=player, - enemy=enemy, - ball=ball, - score_player=state.player_score, - score_enemy=state.enemy_score, ) @partial(jax.jit, static_argnums=(0,)) @@ -412,16 +321,6 @@ def obs_to_flat_array(self, obs: UpNDownObservation) -> jnp.ndarray: obs.player.y.flatten(), obs.player.height.flatten(), obs.player.width.flatten(), - obs.enemy.x.flatten(), - obs.enemy.y.flatten(), - obs.enemy.height.flatten(), - obs.enemy.width.flatten(), - obs.ball.x.flatten(), - obs.ball.y.flatten(), - obs.ball.height.flatten(), - obs.ball.width.flatten(), - obs.score_player.flatten(), - obs.score_enemy.flatten() ] ) @@ -436,20 +335,6 @@ def observation_space(self) -> spaces: "width": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), "height": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), }), - "enemy": spaces.Dict({ - "x": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), - "y": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), - "width": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), - "height": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), - }), - "ball": spaces.Dict({ - "x": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), - "y": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), - "width": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), - "height": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), - }), - "score_player": spaces.Box(low=0, high=21, shape=(), dtype=jnp.int32), - "score_enemy": spaces.Box(low=0, high=21, shape=(), dtype=jnp.int32), }) def image_space(self) -> spaces.Box: @@ -462,20 +347,15 @@ def image_space(self) -> spaces.Box: @partial(jax.jit, static_argnums=(0,)) def _get_info(self, state: UpNDownState, ) -> UpNDownInfo: - return UpNDownInfo(time=state.step_counter) + return UpNDownInfo(time=1) @partial(jax.jit, static_argnums=(0,)) def _get_reward(self, previous_state: UpNDownState, state: UpNDownState): - return (state.player_score - state.enemy_score) - ( - previous_state.player_score - previous_state.enemy_score - ) + return state.score @partial(jax.jit, static_argnums=(0,)) def _get_done(self, state: UpNDownState) -> bool: - return jnp.logical_or( - jnp.greater_equal(state.player_score, 21), - jnp.greater_equal(state.enemy_score, 21), - ) + return jnp.logical_not(True) class UpNDownRenderer(JAXGameRenderer): def __init__(self, consts: UpNDownConstants = None): @@ -487,13 +367,10 @@ def __init__(self, consts: UpNDownConstants = None): #downscale=(84, 84) ) self.jr = render_utils.JaxRenderingUtils(self.config) - # 1. Create procedural assets for both walls - wall_sprite_top = self._create_wall_sprite(self.consts.WALL_TOP_HEIGHT) - wall_sprite_bottom = self._create_wall_sprite(self.consts.WALL_BOTTOM_HEIGHT) # 2. Update asset config to include both walls - asset_config = self._get_asset_config(wall_sprite_top, wall_sprite_bottom) - sprite_path = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/UpNDown" + asset_config = self._get_asset_config() + sprite_path = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/" # 3. Make a single call to the setup function ( @@ -504,25 +381,11 @@ def __init__(self, consts: UpNDownConstants = None): self.FLIP_OFFSETS ) = self.jr.load_and_setup_assets(asset_config, sprite_path) - def _create_wall_sprite(self, height: int) -> jnp.ndarray: - """Procedurally creates an RGBA sprite for a wall of given height.""" - wall_color_rgba = (*self.consts.SCORE_COLOR, 255) # e.g., (236, 236, 236, 255) - wall_shape = (height, self.consts.WIDTH, 4) - wall_sprite = jnp.tile(jnp.array(wall_color_rgba, dtype=jnp.uint8), (*wall_shape[:2], 1)) - return wall_sprite - - def _get_asset_config(self, wall_sprite_top: jnp.ndarray, wall_sprite_bottom: jnp.ndarray) -> list: + def _get_asset_config(self) -> list: """Returns the declarative manifest of all assets for the game, including both wall sprites.""" return [ - {'name': 'background', 'type': 'background', 'file': 'background.npy'}, - {'name': 'player', 'type': 'single', 'file': 'player.npy'}, - {'name': 'enemy', 'type': 'single', 'file': 'enemy.npy'}, - {'name': 'ball', 'type': 'single', 'file': 'ball.npy'}, - {'name': 'player_digits', 'type': 'digits', 'pattern': 'player_score_{}.npy'}, - {'name': 'enemy_digits', 'type': 'digits', 'pattern': 'enemy_score_{}.npy'}, - # Add the procedurally created sprites to the manifest - {'name': 'wall_top', 'type': 'procedural', 'data': wall_sprite_top}, - {'name': 'wall_bottom', 'type': 'procedural', 'data': wall_sprite_bottom}, + {'name': 'background', 'type': 'background', 'file': 'background/background1.npy'}, + {'name': 'player', 'type': 'single', 'file': 'player_car.npy'}, ] @partial(jax.jit, static_argnums=(0,)) @@ -530,46 +393,6 @@ def render(self, state): raster = self.jr.create_object_raster(self.BACKGROUND) player_mask = self.SHAPE_MASKS["player"] - raster = self.jr.render_at(raster, self.consts.PLAYER_X, state.player_y, player_mask) - - enemy_mask = self.SHAPE_MASKS["enemy"] - raster = self.jr.render_at(raster, self.consts.ENEMY_X, state.enemy_y, enemy_mask) - - ball_mask = self.SHAPE_MASKS["ball"] - raster = self.jr.render_at(raster, state.ball_x, state.ball_y, ball_mask) - - # --- Stamp Walls and Score (using the same color/ID) --- - score_color_tuple = self.consts.SCORE_COLOR # (236, 236, 236) - score_id = self.COLOR_TO_ID[score_color_tuple] - - # Draw walls (using separate sprites for top and bottom) - raster = self.jr.render_at(raster, 0, self.consts.WALL_TOP_Y, self.SHAPE_MASKS["wall_top"]) - raster = self.jr.render_at(raster, 0, self.consts.WALL_BOTTOM_Y, self.SHAPE_MASKS["wall_bottom"]) - - # Stamp Score using the label utility - player_digits = self.jr.int_to_digits(state.player_score, max_digits=2) - enemy_digits = self.jr.int_to_digits(state.enemy_score, max_digits=2) - - # Note: The logic for single/double digits is complex for a jitted function. - player_digit_masks = self.SHAPE_MASKS["player_digits"] # Assumes single color - enemy_digit_masks = self.SHAPE_MASKS["enemy_digits"] # Assumes single color - - is_player_single_digit = state.player_score < 10 - player_start_index = jax.lax.select(is_player_single_digit, 1, 0) - player_num_to_render = jax.lax.select(is_player_single_digit, 1, 2) - player_render_x = jax.lax.select(is_player_single_digit, - 120 + 16 // 2, - 120) - - raster = self.jr.render_label_selective(raster, player_render_x, 3, player_digits, player_digit_masks, player_start_index, player_num_to_render, spacing=16) - - is_enemy_single_digit = state.enemy_score < 10 - enemy_start_index = jax.lax.select(is_enemy_single_digit, 1, 0) - enemy_num_to_render = jax.lax.select(is_enemy_single_digit, 1, 2) - enemy_render_x = jax.lax.select(is_enemy_single_digit, - 10 + 16 // 2, - 10) - - raster = self.jr.render_label_selective(raster, enemy_render_x, 3, enemy_digits, enemy_digit_masks, enemy_start_index, enemy_num_to_render, spacing=16) + raster = self.jr.render_at(raster, state.player_car.position.x, state.player_car.position.y, player_mask) return self.jr.render_from_palette(raster, self.PALETTE) \ No newline at end of file diff --git a/src/jaxatari/games/sprites/up_n_down/backround/backround1.npy b/src/jaxatari/games/sprites/up_n_down/background/background1.npy similarity index 100% rename from src/jaxatari/games/sprites/up_n_down/backround/backround1.npy rename to src/jaxatari/games/sprites/up_n_down/background/background1.npy diff --git a/src/jaxatari/games/sprites/up_n_down/backround/backround10.npy b/src/jaxatari/games/sprites/up_n_down/background/background10.npy similarity index 100% rename from src/jaxatari/games/sprites/up_n_down/backround/backround10.npy rename to src/jaxatari/games/sprites/up_n_down/background/background10.npy diff --git a/src/jaxatari/games/sprites/up_n_down/backround/backround11.npy b/src/jaxatari/games/sprites/up_n_down/background/background11.npy similarity index 100% rename from src/jaxatari/games/sprites/up_n_down/backround/backround11.npy rename to src/jaxatari/games/sprites/up_n_down/background/background11.npy diff --git a/src/jaxatari/games/sprites/up_n_down/backround/backround12.npy b/src/jaxatari/games/sprites/up_n_down/background/background12.npy similarity index 100% rename from src/jaxatari/games/sprites/up_n_down/backround/backround12.npy rename to src/jaxatari/games/sprites/up_n_down/background/background12.npy diff --git a/src/jaxatari/games/sprites/up_n_down/backround/backround13.npy b/src/jaxatari/games/sprites/up_n_down/background/background13.npy similarity index 100% rename from src/jaxatari/games/sprites/up_n_down/backround/backround13.npy rename to src/jaxatari/games/sprites/up_n_down/background/background13.npy diff --git a/src/jaxatari/games/sprites/up_n_down/backround/backround2.npy b/src/jaxatari/games/sprites/up_n_down/background/background2.npy similarity index 100% rename from src/jaxatari/games/sprites/up_n_down/backround/backround2.npy rename to src/jaxatari/games/sprites/up_n_down/background/background2.npy diff --git a/src/jaxatari/games/sprites/up_n_down/backround/backround3.npy b/src/jaxatari/games/sprites/up_n_down/background/background3.npy similarity index 100% rename from src/jaxatari/games/sprites/up_n_down/backround/backround3.npy rename to src/jaxatari/games/sprites/up_n_down/background/background3.npy diff --git a/src/jaxatari/games/sprites/up_n_down/backround/backround4.npy b/src/jaxatari/games/sprites/up_n_down/background/background4.npy similarity index 100% rename from src/jaxatari/games/sprites/up_n_down/backround/backround4.npy rename to src/jaxatari/games/sprites/up_n_down/background/background4.npy diff --git a/src/jaxatari/games/sprites/up_n_down/backround/backround5.npy b/src/jaxatari/games/sprites/up_n_down/background/background5.npy similarity index 100% rename from src/jaxatari/games/sprites/up_n_down/backround/backround5.npy rename to src/jaxatari/games/sprites/up_n_down/background/background5.npy diff --git a/src/jaxatari/games/sprites/up_n_down/backround/backround6.npy b/src/jaxatari/games/sprites/up_n_down/background/background6.npy similarity index 100% rename from src/jaxatari/games/sprites/up_n_down/backround/backround6.npy rename to src/jaxatari/games/sprites/up_n_down/background/background6.npy diff --git a/src/jaxatari/games/sprites/up_n_down/backround/backround7.npy b/src/jaxatari/games/sprites/up_n_down/background/background7.npy similarity index 100% rename from src/jaxatari/games/sprites/up_n_down/backround/backround7.npy rename to src/jaxatari/games/sprites/up_n_down/background/background7.npy diff --git a/src/jaxatari/games/sprites/up_n_down/backround/backround8.npy b/src/jaxatari/games/sprites/up_n_down/background/background8.npy similarity index 100% rename from src/jaxatari/games/sprites/up_n_down/backround/backround8.npy rename to src/jaxatari/games/sprites/up_n_down/background/background8.npy diff --git a/src/jaxatari/games/sprites/up_n_down/backround/backround9.npy b/src/jaxatari/games/sprites/up_n_down/background/background9.npy similarity index 100% rename from src/jaxatari/games/sprites/up_n_down/backround/backround9.npy rename to src/jaxatari/games/sprites/up_n_down/background/background9.npy From eaa29de59385ac9f90ec22e1aab9c0aa636dd0e0 Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sun, 16 Nov 2025 13:34:26 +0100 Subject: [PATCH 06/76] use black backround and top and bottom wall sprites --- src/jaxatari/games/jax_upndown.py | 35 +++++++++++++++++++++++++------ 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 4d63a6455..e1ed585b7 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -270,8 +270,8 @@ def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: is_on_road=True, player_car=Car( position=EntityPosition( - x=50, - y=50, + x=30, + y=105, width=self.consts.PLAYER_SIZE[0], height=self.consts.PLAYER_SIZE[1], ), @@ -367,9 +367,13 @@ def __init__(self, consts: UpNDownConstants = None): #downscale=(84, 84) ) self.jr = render_utils.JaxRenderingUtils(self.config) + + background = self._createBackgroundSprite(self.config.game_dimensions) + top_block = self._createBackgroundSprite((25, self.config.game_dimensions[1])) + bottom_block = self._createBackgroundSprite((16, self.config.game_dimensions[1])) # 2. Update asset config to include both walls - asset_config = self._get_asset_config() + asset_config = self._get_asset_config(background, top_block, bottom_block) sprite_path = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/" # 3. Make a single call to the setup function @@ -381,18 +385,37 @@ def __init__(self, consts: UpNDownConstants = None): self.FLIP_OFFSETS ) = self.jr.load_and_setup_assets(asset_config, sprite_path) - def _get_asset_config(self) -> list: + def _createBackgroundSprite(self, dimensions: Tuple[int, int]) -> jnp.ndarray: + """Creates a procedural background sprite for the game.""" + height, width = dimensions + color = (0, 0, 0, 255) # RGBA for wall color + shape = (height, width, 4) # Height, Width, RGBA channels + sprite = jnp.tile(jnp.array(color, dtype=jnp.uint8), (*shape[:2], 1)) + return sprite + + def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.ndarray, bottomBlockSprite: jnp.ndarray) -> list: """Returns the declarative manifest of all assets for the game, including both wall sprites.""" return [ - {'name': 'background', 'type': 'background', 'file': 'background/background1.npy'}, + {'name': 'background', 'type': 'background', 'data': backgroundSprite}, + {'name': 'road1', 'type': 'single', 'file': 'background/background1.npy'}, {'name': 'player', 'type': 'single', 'file': 'player_car.npy'}, + {'name': 'wall_top', 'type': 'procedural', 'data': topBlockSprite}, + {'name': 'wall_bottom', 'type': 'procedural', 'data': bottomBlockSprite}, ] @partial(jax.jit, static_argnums=(0,)) def render(self, state): raster = self.jr.create_object_raster(self.BACKGROUND) + road1_mask = self.SHAPE_MASKS["road1"] + raster = self.jr.render_at(raster, 10, 25, road1_mask) player_mask = self.SHAPE_MASKS["player"] - raster = self.jr.render_at(raster, state.player_car.position.x, state.player_car.position.y, player_mask) + raster = self.jr.render_at(raster, state.player_car.position.x, 105, player_mask) + + wall_top_mask = self.SHAPE_MASKS["wall_top"] + raster = self.jr.render_at(raster, 0, 0, wall_top_mask) + + wall_bottom_mask = self.SHAPE_MASKS["wall_bottom"] + raster = self.jr.render_at(raster, 0, 210 - wall_bottom_mask.shape[0], wall_bottom_mask) return self.jr.render_from_palette(raster, self.PALETTE) \ No newline at end of file From e70d91380ce17ccaf45fb4ec1a1108c427e4a6fd Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sun, 16 Nov 2025 14:57:43 +0100 Subject: [PATCH 07/76] add first movment of player and map to game --- src/jaxatari/games/jax_upndown.py | 41 +++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index e1ed585b7..5d57a07e2 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -20,11 +20,12 @@ class UpNDownConstants(NamedTuple): LANDING_ZONE: int = 15 FIRST_ROAD_LENGTH: int = 4 SECOND_ROAD_LENGTH: int = 4 - FIRST_TRACK_CORNERS_X: chex.Array = jnp.array([20, 50, 80, 100]) #get actual values - FIRST_TRACK_CORNERS_Y: chex.Array = jnp.array([20, 50, 80, 100]) #get actual values + FIRST_TRACK_CORNERS_X: chex.Array = jnp.array([30, 80, 140, 100]) #get actual values + FIRST_TRACK_CORNERS_Y: chex.Array = jnp.array([105, 80, 25, 100]) #get actual values SECOND_TRACK_CORNERS_X: chex.Array = jnp.array([20, 50, 80, 100]) #get actual values SECOND_TRACK_CORNERS_Y: chex.Array = jnp.array([20, 50, 80, 100]) #get actual values PLAYER_SIZE: Tuple[int, int] = (4, 16) + INITIAL_ROAD_POS_Y: int = 25 @@ -51,6 +52,7 @@ class UpNDownState(NamedTuple): is_jumping: chex.Array is_on_road: chex.Array player_car: Car + road_diff: chex.Array @@ -175,7 +177,7 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: ) - is_jumping = jnp.logical_or(jnp.logical_and(state.is_jumping, state.jump_cooldown > 0), jnp.logical_and(state.is_on_road, jnp.logical_and(player_speed > 0, state.jump_cooldown == 0))) + is_jumping = jnp.logical_or(jnp.logical_and(state.is_jumping, state.jump_cooldown > 0), jnp.logical_and(state.is_on_road, jnp.logical_and(player_speed >= 0, jnp.logical_and(state.jump_cooldown == 0, jump)))) jump_cooldown = jax.lax.cond( state.jump_cooldown > 0, lambda s: s - 1, @@ -202,21 +204,24 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: ) + car_direction_x = jax.lax.cond(state.player_car.current_road == 0, + lambda s: self.consts.FIRST_TRACK_CORNERS_X[road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_X[road_index_A], + lambda s: self.consts.SECOND_TRACK_CORNERS_X[road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[road_index_B], + operand=None), car_direction_x = jax.lax.cond( - direction_change, - lambda s: jax.lax.cond(state.player_car.current_road == 0, - lambda s: self.consts.FIRST_TRACK_CORNERS_X[road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_X[road_index_A], - lambda s: self.consts.SECOND_TRACK_CORNERS_X[road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[road_index_B], - operand=None), - lambda s: s, - operand=state.player_car.direction_x, + car_direction_x[0] > 0, + lambda s: 1, + lambda s: -1, + operand=car_direction_x, ) + is_landing = jnp.logical_and(state.jump_cooldown == 1, jump_cooldown == 0) ##calculate new position with speed (TODO: calculate better speed) player_y = state.player_car.position.y + player_speed player_x = state.player_car.position.x + player_speed * car_direction_x + jax.debug.print("Player X: {}, Player Y: {}, car_direction_x: {}", player_x, player_y, car_direction_x) landing_in_Water, between_roads, road_A_x, road_B_x = self._landing_in_water(state, player_x, player_y) landing_in_Water = jnp.logical_and(is_landing, landing_in_Water) @@ -238,6 +243,9 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: ), operand=None, ) + road_diff = state.road_diff + player_speed + jax.debug.print("road_diff: {}", road_diff) + #jax.debug.print("Player X: {}, Player Y: {}, on road: {}, jumping: {}, speed: {}, road index A: {}, road index B: {}, current road: {}", player_x, player_y, is_on_road, is_jumping, player_speed, road_index_A, road_index_B, current_road) return UpNDownState( score=state.score, difficulty=state.difficulty, @@ -258,6 +266,7 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: road_index_B=road_index_B, type=state.player_car.type, ), + road_diff=state.road_diff + player_speed, ) @@ -282,6 +291,7 @@ def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: road_index_B=0, type=0, ), + road_diff=0, ) initial_obs = self._get_observation(state) @@ -371,9 +381,10 @@ def __init__(self, consts: UpNDownConstants = None): background = self._createBackgroundSprite(self.config.game_dimensions) top_block = self._createBackgroundSprite((25, self.config.game_dimensions[1])) bottom_block = self._createBackgroundSprite((16, self.config.game_dimensions[1])) + temp_pointer = self._createBackgroundSprite((1, 1)) # 2. Update asset config to include both walls - asset_config = self._get_asset_config(background, top_block, bottom_block) + asset_config = self._get_asset_config(background, top_block, bottom_block, temp_pointer) sprite_path = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/" # 3. Make a single call to the setup function @@ -393,7 +404,7 @@ def _createBackgroundSprite(self, dimensions: Tuple[int, int]) -> jnp.ndarray: sprite = jnp.tile(jnp.array(color, dtype=jnp.uint8), (*shape[:2], 1)) return sprite - def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.ndarray, bottomBlockSprite: jnp.ndarray) -> list: + def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.ndarray, bottomBlockSprite: jnp.ndarray, tempPointer: jnp.ndarray) -> list: """Returns the declarative manifest of all assets for the game, including both wall sprites.""" return [ {'name': 'background', 'type': 'background', 'data': backgroundSprite}, @@ -401,13 +412,14 @@ def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.n {'name': 'player', 'type': 'single', 'file': 'player_car.npy'}, {'name': 'wall_top', 'type': 'procedural', 'data': topBlockSprite}, {'name': 'wall_bottom', 'type': 'procedural', 'data': bottomBlockSprite}, + {'name': 'tempPointer', 'type': 'procedural', 'data': tempPointer}, ] @partial(jax.jit, static_argnums=(0,)) def render(self, state): raster = self.jr.create_object_raster(self.BACKGROUND) road1_mask = self.SHAPE_MASKS["road1"] - raster = self.jr.render_at(raster, 10, 25, road1_mask) + raster = self.jr.render_at_clipped(raster, 10, self.consts.INITIAL_ROAD_POS_Y + state.road_diff, road1_mask) player_mask = self.SHAPE_MASKS["player"] raster = self.jr.render_at(raster, state.player_car.position.x, 105, player_mask) @@ -418,4 +430,7 @@ def render(self, state): wall_bottom_mask = self.SHAPE_MASKS["wall_bottom"] raster = self.jr.render_at(raster, 0, 210 - wall_bottom_mask.shape[0], wall_bottom_mask) + wall_bottom_mask = self.SHAPE_MASKS["tempPointer"] + raster = self.jr.render_at(raster, 140, 26, wall_bottom_mask) + return self.jr.render_from_palette(raster, self.PALETTE) \ No newline at end of file From 2b9ea9f1c72b17af63660642262fc54388faee65 Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sun, 16 Nov 2025 16:46:33 +0100 Subject: [PATCH 08/76] add new parts of the map --- src/jaxatari/games/jax_upndown.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 5d57a07e2..0e58ecafb 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -395,6 +395,7 @@ def __init__(self, consts: UpNDownConstants = None): self.COLOR_TO_ID, self.FLIP_OFFSETS ) = self.jr.load_and_setup_assets(asset_config, sprite_path) + self.road_sizes = self._get_road_sprite_sizes() def _createBackgroundSprite(self, dimensions: Tuple[int, int]) -> jnp.ndarray: """Creates a procedural background sprite for the game.""" @@ -403,12 +404,24 @@ def _createBackgroundSprite(self, dimensions: Tuple[int, int]) -> jnp.ndarray: shape = (height, width, 4) # Height, Width, RGBA channels sprite = jnp.tile(jnp.array(color, dtype=jnp.uint8), (*shape[:2], 1)) return sprite + + def _get_road_sprite_sizes(self) -> list: + """Returns the sizes of the road sprites.""" + sizes = [] + for file in os.listdir(f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/background/"): + sprite = jnp.load(f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/background/{file}") + sizes.append(sprite.shape[0]) + jax.debug.print("Road sizes: {}", sizes) + return sizes def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.ndarray, bottomBlockSprite: jnp.ndarray, tempPointer: jnp.ndarray) -> list: """Returns the declarative manifest of all assets for the game, including both wall sprites.""" + roads = [] + for x in range(13): + roads.append(f"background/background{x+1}.npy") return [ {'name': 'background', 'type': 'background', 'data': backgroundSprite}, - {'name': 'road1', 'type': 'single', 'file': 'background/background1.npy'}, + {'name': 'road', 'type': 'group', 'files': roads}, {'name': 'player', 'type': 'single', 'file': 'player_car.npy'}, {'name': 'wall_top', 'type': 'procedural', 'data': topBlockSprite}, {'name': 'wall_bottom', 'type': 'procedural', 'data': bottomBlockSprite}, @@ -418,8 +431,14 @@ def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.n @partial(jax.jit, static_argnums=(0,)) def render(self, state): raster = self.jr.create_object_raster(self.BACKGROUND) - road1_mask = self.SHAPE_MASKS["road1"] + + road1_mask = self.SHAPE_MASKS["road"][0] raster = self.jr.render_at_clipped(raster, 10, self.consts.INITIAL_ROAD_POS_Y + state.road_diff, road1_mask) + diff = 0 + for i in range(12): + road1_mask = self.SHAPE_MASKS["road"][i+1] + diff += self.road_sizes[i+1] + raster = self.jr.render_at_clipped(raster, 10, self.consts.INITIAL_ROAD_POS_Y + state.road_diff - diff, road1_mask) player_mask = self.SHAPE_MASKS["player"] raster = self.jr.render_at(raster, state.player_car.position.x, 105, player_mask) From 958e021d4422b4b21f706adefddabdfcbd108b40 Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sun, 23 Nov 2025 21:10:27 +0100 Subject: [PATCH 09/76] changes to player movement and road selection --- src/jaxatari/games/jax_upndown.py | 223 ++++++++++++++++++++---------- 1 file changed, 149 insertions(+), 74 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 0e58ecafb..6651a57b2 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -15,13 +15,13 @@ class UpNDownConstants(NamedTuple): FRAME_SKIP: int = 4 DIFFICULTIES: chex.Array = jnp.array([0, 1, 2, 3, 4, 5]) ACTION_REPEAT_PROBS: float = 0.25 - MAX_SPEED: int = 4 + MAX_SPEED: int = 1 JUMP_FRAMES: int = 10 LANDING_ZONE: int = 15 FIRST_ROAD_LENGTH: int = 4 SECOND_ROAD_LENGTH: int = 4 - FIRST_TRACK_CORNERS_X: chex.Array = jnp.array([30, 80, 140, 100]) #get actual values - FIRST_TRACK_CORNERS_Y: chex.Array = jnp.array([105, 80, 25, 100]) #get actual values + FIRST_TRACK_CORNERS_X: chex.Array = jnp.array([30, 80, 140, 80]) #get actual values + FIRST_TRACK_CORNERS_Y: chex.Array = jnp.array([105, 80, 25, 0]) #get actual values SECOND_TRACK_CORNERS_X: chex.Array = jnp.array([20, 50, 80, 100]) #get actual values SECOND_TRACK_CORNERS_Y: chex.Array = jnp.array([20, 50, 80, 100]) #get actual values PLAYER_SIZE: Tuple[int, int] = (4, 16) @@ -52,7 +52,7 @@ class UpNDownState(NamedTuple): is_jumping: chex.Array is_on_road: chex.Array player_car: Car - road_diff: chex.Array + step_counter: chex.Array @@ -89,59 +89,29 @@ def __init__(self, consts: UpNDownConstants = None, reward_funcs: list[callable] self.obs_size = 3*4+1+1 @partial(jax.jit, static_argnums=(0,)) - def _car_past_corner(self, car: Car, state: UpNDownState) -> chex.Array: - direction_change_A = jnp.logical_or(jnp.logical_and(car.speed > 0, car.position.y + car.speed > self.consts.FIRST_TRACK_CORNERS_Y[car.road_index_A+1]), jnp.logical_and(car.speed < 0, car.position.y + car.speed < self.consts.FIRST_TRACK_CORNERS_Y[car.road_index_A])) - direction_change_B = jnp.logical_or(jnp.logical_and(car.speed > 0, car.position.y + car.speed > self.consts.SECOND_TRACK_CORNERS_Y[car.road_index_B+1]), jnp.logical_and(car.speed < 0, car.position.y + car.speed < self.consts.SECOND_TRACK_CORNERS_Y[car.road_index_B])) - - road_index_A = jax.lax.cond(jnp.logical_and(direction_change_A, car.speed > 0), - lambda s: s + 1, - lambda s: s, - operand=car.road_index_A, - ) - road_index_A = jax.lax.cond(jnp.logical_and(direction_change_A, car.speed < 0), - lambda s: s - 1, - lambda s: s, - operand=car.road_index_A, - ) - - road_index_B = jax.lax.cond(jnp.logical_and(direction_change_B, car.speed > 0), - lambda s: s + 1, - lambda s: s, - operand=car.road_index_B, - ) - road_index_B = jax.lax.cond(jnp.logical_and(direction_change_B, car.speed < 0), - lambda s: s - 1, - lambda s: s, - operand=car.road_index_B, - ) - current_road_length_A = self.consts.FIRST_ROAD_LENGTH - current_road_length_B = self.consts.SECOND_ROAD_LENGTH - - road_index_A = jax.lax.cond(road_index_A < 0, - lambda s: current_road_length_A - 1, - lambda s: s, - operand=road_index_A, - ) - - road_index_A = jax.lax.cond(road_index_A >= current_road_length_A, - lambda s: 0, - lambda s: s, - operand=road_index_A, - ) - - road_index_B = jax.lax.cond(road_index_B < 0, - lambda s: current_road_length_B - 1, - lambda s: s, - operand=road_index_B, - ) - - road_index_B = jax.lax.cond(road_index_B >= current_road_length_B, - lambda s: 0, - lambda s: s, - operand=road_index_B, + def _getSlopeAndB(self, state: UpNDownState) -> chex.Array: + trackx, tracky, roadIndex = jax.lax.cond( + state.player_car.current_road == 0, + lambda s: (self.consts.FIRST_TRACK_CORNERS_X, self.consts.FIRST_TRACK_CORNERS_Y, state.player_car.road_index_A), + lambda s: (self.consts.SECOND_TRACK_CORNERS_X, self.consts.SECOND_TRACK_CORNERS_Y, state.player_car.road_index_B), + operand=None,) + slope = jax.lax.cond( + trackx[roadIndex+1] - trackx[roadIndex] != 0, + lambda s: (tracky[roadIndex+1] - tracky[roadIndex]) / (trackx[roadIndex+1] - trackx[roadIndex]), + lambda s: jnp.inf, + operand=None, ) + b = tracky[roadIndex] - slope * trackx[roadIndex] + return slope, b + + @partial(jax.jit, static_argnums=(0,)) + def _isOnLine(self, state: UpNDownState, new_position_x: chex.Array, new_position_y: chex.Array, player_speed: chex.Array) -> chex.Array: + slope, b = self._getSlopeAndB(state) + jax.debug.print("slope: {}, b: {}", slope, b) + isOnLine = jnp.less_equal(jnp.abs(jnp.round(jnp.subtract(new_position_y, slope * new_position_x + b))), player_speed) - return road_index_A, road_index_B + jax.debug.print("isOnLine: {}", jnp.subtract(new_position_y, slope * new_position_x + b)) + return isOnLine @partial(jax.jit, static_argnums=(0,)) def _landing_in_water(self, state: UpNDownState, new_position_x: chex.Array, new_position_y: chex.Array) -> chex.Array: @@ -219,8 +189,39 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: is_landing = jnp.logical_and(state.jump_cooldown == 1, jump_cooldown == 0) ##calculate new position with speed (TODO: calculate better speed) - player_y = state.player_car.position.y + player_speed - player_x = state.player_car.position.x + player_speed * car_direction_x + player_y = jax.lax.cond( + state.step_counter % 8 == 4, + lambda s: jax.lax.cond( + is_jumping, + lambda s: state.player_car.position.y + player_speed * -1, + lambda s: jax.lax.cond( + self._isOnLine(state, state.player_car.position.x, s + player_speed * -1, player_speed), + lambda s: s + player_speed * -1, + lambda s: s, + operand=state.player_car.position.y, + ), + operand=state.player_car.position.y), + lambda s: state.player_car.position.y, + operand=None, + ) + player_x = jax.lax.cond( + state.step_counter % 8 == 0, + lambda s: jax.lax.cond( + is_jumping, + lambda s: s + player_speed * car_direction_x, + lambda s: jax.lax.cond( + self._isOnLine(state, s + player_speed * car_direction_x, player_y, player_speed), + lambda s: s + player_speed * car_direction_x, + lambda s: s, + operand=state.player_car.position.x, + ), + operand=state.player_car.position.x), + lambda s: s, + operand=state.player_car.position.x, + ) + + ##if y not on mx +b then no move + jax.debug.print("Player X: {}, Player Y: {}, car_direction_x: {}", player_x, player_y, car_direction_x) landing_in_Water, between_roads, road_A_x, road_B_x = self._landing_in_water(state, player_x, player_y) @@ -243,8 +244,61 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: ), operand=None, ) - road_diff = state.road_diff + player_speed - jax.debug.print("road_diff: {}", road_diff) + + road_index_A = jax.lax.cond( + current_road == 2, + lambda s: road_index_A, + lambda s: jax.lax.cond( + self.consts.FIRST_TRACK_CORNERS_Y[road_index_A] < player_y, + lambda s: road_index_A - 1, + lambda s: jax.lax.cond( + len(self.consts.FIRST_TRACK_CORNERS_Y) == road_index_A + 1, + lambda s: jax.lax.cond( + self.consts.FIRST_TRACK_CORNERS_Y[0] > player_y, + lambda s: 0, + lambda s: road_index_A, + operand=None, + ), + lambda s: jax.lax.cond( + self.consts.FIRST_TRACK_CORNERS_Y[road_index_A+1] > player_y, + lambda s: road_index_A + 1, + lambda s: road_index_A, + operand=None, + ), + operand=None, + ), + operand=None, + ), + operand=None, + ) + + road_index_B = jax.lax.cond( + current_road == 2, + lambda s: road_index_B, + lambda s: jax.lax.cond( + self.consts.SECOND_TRACK_CORNERS_Y[road_index_B] < player_y, + lambda s: road_index_B - 1, + lambda s: jax.lax.cond( + len(self.consts.SECOND_TRACK_CORNERS_Y) == road_index_B + 1, + lambda s: jax.lax.cond( + self.consts.SECOND_TRACK_CORNERS_Y[0] > player_y, + lambda s: 0, + lambda s: road_index_B, + operand=None, + ), + lambda s: jax.lax.cond( + self.consts.SECOND_TRACK_CORNERS_Y[road_index_B+1] > player_y, + lambda s: road_index_B + 1, + lambda s: road_index_B, + operand=None, + ), + operand=None, + ), + operand=None, + ), + operand=None, + ) + #jax.debug.print("Player X: {}, Player Y: {}, on road: {}, jumping: {}, speed: {}, road index A: {}, road index B: {}, current road: {}", player_x, player_y, is_on_road, is_jumping, player_speed, road_index_A, road_index_B, current_road) return UpNDownState( score=state.score, @@ -266,7 +320,7 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: road_index_B=road_index_B, type=state.player_car.type, ), - road_diff=state.road_diff + player_speed, + step_counter=state.step_counter + 1, ) @@ -291,10 +345,9 @@ def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: road_index_B=0, type=0, ), - road_diff=0, + step_counter=jnp.array(0), ) initial_obs = self._get_observation(state) - return initial_obs, state @partial(jax.jit, static_argnums=(0,)) @@ -316,7 +369,7 @@ def render(self, state: UpNDownState) -> jnp.ndarray: def _get_observation(self, state: UpNDownState): player = EntityPosition( x=jnp.array(state.player_car.position.x), - y=state.player_car.position.y, + y=jnp.array(state.player_car.position.y), width=jnp.array(self.consts.PLAYER_SIZE[0]), height=jnp.array(self.consts.PLAYER_SIZE[1]), ) @@ -395,7 +448,7 @@ def __init__(self, consts: UpNDownConstants = None): self.COLOR_TO_ID, self.FLIP_OFFSETS ) = self.jr.load_and_setup_assets(asset_config, sprite_path) - self.road_sizes = self._get_road_sprite_sizes() + self.road_sizes, self.complete_road_size = self._get_road_sprite_sizes() def _createBackgroundSprite(self, dimensions: Tuple[int, int]) -> jnp.ndarray: """Creates a procedural background sprite for the game.""" @@ -408,11 +461,13 @@ def _createBackgroundSprite(self, dimensions: Tuple[int, int]) -> jnp.ndarray: def _get_road_sprite_sizes(self) -> list: """Returns the sizes of the road sprites.""" sizes = [] + complete_size = 0 for file in os.listdir(f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/background/"): sprite = jnp.load(f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/background/{file}") sizes.append(sprite.shape[0]) - jax.debug.print("Road sizes: {}", sizes) - return sizes + if file != "background1.npy": + complete_size += sprite.shape[0] + return sizes, complete_size def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.ndarray, bottomBlockSprite: jnp.ndarray, tempPointer: jnp.ndarray) -> list: """Returns the declarative manifest of all assets for the game, including both wall sprites.""" @@ -431,14 +486,34 @@ def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.n @partial(jax.jit, static_argnums=(0,)) def render(self, state): raster = self.jr.create_object_raster(self.BACKGROUND) + road_diff = (-state.player_car.position.y + 105) % self.complete_road_size + + # Vectorized road rendering: compute all Y offsets, stamp via vmap, fold overlays. + road_masks = self.SHAPE_MASKS["road"] # shape: (N, H, W) + num_segments = road_masks.shape[0] + + sizes = jnp.array(self.road_sizes, dtype=jnp.int32) + # Offsets: [0, cumsum(sizes[1:])] + offsets = jnp.concatenate([ + jnp.array([0], dtype=jnp.int32), + jnp.cumsum(sizes[1:], axis=0) + ], axis=0) + + base_y = jnp.asarray(self.consts.INITIAL_ROAD_POS_Y, dtype=jnp.int32) + y_positions = base_y + (road_diff.astype(jnp.int32)) - offsets + + empty_raster = jnp.full_like(self.BACKGROUND, self.jr.TRANSPARENT_ID) + + def stamp(y, mask): + return self.jr.render_at_clipped(empty_raster, 10, y, mask) + + overlays = jax.vmap(stamp)(y_positions, road_masks) + + def combine(i, acc): + over = overlays[i] + return jnp.where(over != self.jr.TRANSPARENT_ID, over, acc) - road1_mask = self.SHAPE_MASKS["road"][0] - raster = self.jr.render_at_clipped(raster, 10, self.consts.INITIAL_ROAD_POS_Y + state.road_diff, road1_mask) - diff = 0 - for i in range(12): - road1_mask = self.SHAPE_MASKS["road"][i+1] - diff += self.road_sizes[i+1] - raster = self.jr.render_at_clipped(raster, 10, self.consts.INITIAL_ROAD_POS_Y + state.road_diff - diff, road1_mask) + raster = jax.lax.fori_loop(0, num_segments, combine, raster) player_mask = self.SHAPE_MASKS["player"] raster = self.jr.render_at(raster, state.player_car.position.x, 105, player_mask) @@ -450,6 +525,6 @@ def render(self, state): raster = self.jr.render_at(raster, 0, 210 - wall_bottom_mask.shape[0], wall_bottom_mask) wall_bottom_mask = self.SHAPE_MASKS["tempPointer"] - raster = self.jr.render_at(raster, 140, 26, wall_bottom_mask) + raster = self.jr.render_at(raster, 140, 25, wall_bottom_mask) return self.jr.render_from_palette(raster, self.PALETTE) \ No newline at end of file From d4353cd0455528d13bf9a398f154f17c889a5844 Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sat, 29 Nov 2025 19:26:08 +0100 Subject: [PATCH 10/76] add logic for different speeds --- src/jaxatari/games/jax_upndown.py | 72 +++++++++++++++++++++---------- 1 file changed, 49 insertions(+), 23 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 6651a57b2..d18f797b0 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -1,5 +1,6 @@ from jax._src.pjit import JitWrapped import os +import math from functools import partial from typing import NamedTuple, Tuple import jax.lax @@ -15,7 +16,7 @@ class UpNDownConstants(NamedTuple): FRAME_SKIP: int = 4 DIFFICULTIES: chex.Array = jnp.array([0, 1, 2, 3, 4, 5]) ACTION_REPEAT_PROBS: float = 0.25 - MAX_SPEED: int = 1 + MAX_SPEED: int = 4 JUMP_FRAMES: int = 10 LANDING_ZONE: int = 15 FIRST_ROAD_LENGTH: int = 4 @@ -145,6 +146,8 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: lambda s: s, operand=player_speed, ) + dividers = jnp.array([0, 1, 2, 4, 8]) + speed_divider = dividers[jnp.abs(player_speed)] is_jumping = jnp.logical_or(jnp.logical_and(state.is_jumping, state.jump_cooldown > 0), jnp.logical_and(state.is_on_road, jnp.logical_and(player_speed >= 0, jnp.logical_and(state.jump_cooldown == 0, jump)))) @@ -164,15 +167,14 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: ##check if player is on the the road is_on_road = ~state.is_jumping - road_index_A, road_index_B = self._car_past_corner(state.player_car, state) - - direction_change = jax.lax.cond( + '''direction_change = jax.lax.cond( jnp.logical_and(is_on_road, jnp.logical_or(jnp.logical_and(jnp.equal(road_index_A, state.player_car.road_index_A) , state.player_car.current_road == 0), (jnp.logical_and(jnp.equal(road_index_B, state.player_car.road_index_B) , state.player_car.current_road == 1)))) , lambda s: False, lambda s: True, operand=None, - ) - + )''' + road_index_A = state.player_car.road_index_A + road_index_B = state.player_car.road_index_B car_direction_x = jax.lax.cond(state.player_car.current_road == 0, lambda s: self.consts.FIRST_TRACK_CORNERS_X[road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_X[road_index_A], @@ -190,33 +192,33 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: ##calculate new position with speed (TODO: calculate better speed) player_y = jax.lax.cond( - state.step_counter % 8 == 4, + state.step_counter % (16/ speed_divider) == 8 / speed_divider, lambda s: jax.lax.cond( is_jumping, - lambda s: state.player_car.position.y + player_speed * -1, + lambda s: state.player_car.position.y + jax.lax.abs(player_speed) / player_speed * -1, lambda s: jax.lax.cond( - self._isOnLine(state, state.player_car.position.x, s + player_speed * -1, player_speed), - lambda s: s + player_speed * -1, - lambda s: s, + self._isOnLine(state, state.player_car.position.x, s + jax.lax.abs(player_speed) / player_speed * -1, 1), + lambda s: s + jax.lax.abs(player_speed) / player_speed * -1, + lambda s: jnp.array(s, float), operand=state.player_car.position.y, ), operand=state.player_car.position.y), - lambda s: state.player_car.position.y, - operand=None, + lambda s: jnp.array(s, float), + operand=state.player_car.position.y, ) player_x = jax.lax.cond( - state.step_counter % 8 == 0, + state.step_counter % (16/ speed_divider) == 0, lambda s: jax.lax.cond( is_jumping, - lambda s: s + player_speed * car_direction_x, + lambda s: s + jax.lax.abs(player_speed) / player_speed * car_direction_x, lambda s: jax.lax.cond( - self._isOnLine(state, s + player_speed * car_direction_x, player_y, player_speed), - lambda s: s + player_speed * car_direction_x, - lambda s: s, + self._isOnLine(state, s + jax.lax.abs(player_speed) / player_speed * car_direction_x, player_y, 1), + lambda s: s + jax.lax.abs(player_speed) / player_speed * car_direction_x, + lambda s: jnp.array(s, float), operand=state.player_car.position.x, ), operand=state.player_car.position.x), - lambda s: s, + lambda s: jnp.array(s, float), operand=state.player_car.position.x, ) @@ -449,6 +451,12 @@ def __init__(self, consts: UpNDownConstants = None): self.FLIP_OFFSETS ) = self.jr.load_and_setup_assets(asset_config, sprite_path) self.road_sizes, self.complete_road_size = self._get_road_sprite_sizes() + self.view_height = self.config.game_dimensions[0] + # Precompute offsets so repeated road tiles can wrap seamlessly without gaps. + road_cycle = max(1, self.complete_road_size) + repeats = max(1, math.ceil(self.view_height / road_cycle) + 2) + self._road_tile_offsets = jnp.arange(-repeats, repeats + 1, dtype=jnp.int32) * jnp.int32(self.complete_road_size) + self._num_road_tiles = int(self._road_tile_offsets.shape[0]) def _createBackgroundSprite(self, dimensions: Tuple[int, int]) -> jnp.ndarray: """Creates a procedural background sprite for the game.""" @@ -502,18 +510,36 @@ def render(self, state): base_y = jnp.asarray(self.consts.INITIAL_ROAD_POS_Y, dtype=jnp.int32) y_positions = base_y + (road_diff.astype(jnp.int32)) - offsets + tile_offsets = self._road_tile_offsets + tile_count = self._num_road_tiles + tiled_y = (y_positions[None, :] + tile_offsets[:, None]).reshape(tile_count * num_segments) + tiled_masks = jnp.tile(road_masks, (tile_count, 1, 1)) + tiled_sizes = jnp.tile(sizes, tile_count) + + visible = jnp.logical_and( + tiled_y < self.view_height, + (tiled_y + tiled_sizes) > 0 + ) + empty_raster = jnp.full_like(self.BACKGROUND, self.jr.TRANSPARENT_ID) - def stamp(y, mask): - return self.jr.render_at_clipped(empty_raster, 10, y, mask) + def stamp(y, mask, is_visible): + return jax.lax.cond( + is_visible, + lambda _: self.jr.render_at_clipped(empty_raster, 10, y, mask), + lambda _: empty_raster, + operand=None, + ) + + overlays = jax.vmap(stamp)(tiled_y, tiled_masks, visible) - overlays = jax.vmap(stamp)(y_positions, road_masks) + total_segments = tile_count * num_segments def combine(i, acc): over = overlays[i] return jnp.where(over != self.jr.TRANSPARENT_ID, over, acc) - raster = jax.lax.fori_loop(0, num_segments, combine, raster) + raster = jax.lax.fori_loop(0, total_segments, combine, raster) player_mask = self.SHAPE_MASKS["player"] raster = self.jr.render_at(raster, state.player_car.position.x, 105, player_mask) From d92272ae140647e4ea078b5219d36c7390cd5c01 Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Tue, 2 Dec 2025 23:49:59 +0100 Subject: [PATCH 11/76] car now follows road with loop --- src/jaxatari/games/jax_upndown.py | 69 ++++++++++++------ .../games/sprites/up_n_down/roads/road1.npy | Bin 0 -> 101108 bytes .../games/sprites/up_n_down/roads/road2.npy | Bin 0 -> 59492 bytes .../games/sprites/up_n_down/roads/road3.npy | Bin 0 -> 84032 bytes .../games/sprites/up_n_down/roads/road4.npy | Bin 0 -> 83972 bytes .../games/sprites/up_n_down/roads/road5.npy | Bin 0 -> 72344 bytes .../games/sprites/up_n_down/roads/road6.npy | Bin 0 -> 90704 bytes .../games/sprites/up_n_down/roads/road7.npy | Bin 0 -> 66224 bytes .../games/sprites/up_n_down/roads/road8.npy | Bin 0 -> 76016 bytes 9 files changed, 46 insertions(+), 23 deletions(-) create mode 100644 src/jaxatari/games/sprites/up_n_down/roads/road1.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/roads/road2.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/roads/road3.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/roads/road4.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/roads/road5.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/roads/road6.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/roads/road7.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/roads/road8.npy diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index d18f797b0..9e30ea7cd 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -21,10 +21,10 @@ class UpNDownConstants(NamedTuple): LANDING_ZONE: int = 15 FIRST_ROAD_LENGTH: int = 4 SECOND_ROAD_LENGTH: int = 4 - FIRST_TRACK_CORNERS_X: chex.Array = jnp.array([30, 80, 140, 80]) #get actual values - FIRST_TRACK_CORNERS_Y: chex.Array = jnp.array([105, 80, 25, 0]) #get actual values - SECOND_TRACK_CORNERS_X: chex.Array = jnp.array([20, 50, 80, 100]) #get actual values - SECOND_TRACK_CORNERS_Y: chex.Array = jnp.array([20, 50, 80, 100]) #get actual values + FIRST_TRACK_CORNERS_X: chex.Array = jnp.array([30, 75, 128, 75, 21, 75, 131, 111, 150, 95, 150, 115, 150, 108, 150, 115, 115, 75, 18, 67, 38, 38, 20, 64, 30]) #get actual values + FIRST_TRACK_CORNERS_Y: chex.Array = jnp.array([105, 65, 7, -50, -98, -163, -222, -242, -277, -362, -420, -460, -492, -520, -565, -600, -633, -683, -733, -793, -820, -845, -867, -895, -928]) #get actual values + SECOND_TRACK_CORNERS_X: chex.Array = FIRST_TRACK_CORNERS_X#jnp.array([20, 50]) #get actual values + SECOND_TRACK_CORNERS_Y: chex.Array = FIRST_TRACK_CORNERS_Y#jnp.array([20, 50, ]) #get actual values PLAYER_SIZE: Tuple[int, int] = (4, 16) INITIAL_ROAD_POS_Y: int = 25 @@ -54,6 +54,7 @@ class UpNDownState(NamedTuple): is_on_road: chex.Array player_car: Car step_counter: chex.Array + road_reset: chex.Array @@ -99,7 +100,7 @@ def _getSlopeAndB(self, state: UpNDownState) -> chex.Array: slope = jax.lax.cond( trackx[roadIndex+1] - trackx[roadIndex] != 0, lambda s: (tracky[roadIndex+1] - tracky[roadIndex]) / (trackx[roadIndex+1] - trackx[roadIndex]), - lambda s: jnp.inf, + lambda s: 300.0, operand=None, ) b = tracky[roadIndex] - slope * trackx[roadIndex] @@ -109,7 +110,7 @@ def _getSlopeAndB(self, state: UpNDownState) -> chex.Array: def _isOnLine(self, state: UpNDownState, new_position_x: chex.Array, new_position_y: chex.Array, player_speed: chex.Array) -> chex.Array: slope, b = self._getSlopeAndB(state) jax.debug.print("slope: {}, b: {}", slope, b) - isOnLine = jnp.less_equal(jnp.abs(jnp.round(jnp.subtract(new_position_y, slope * new_position_x + b))), player_speed) + isOnLine = jnp.logical_or(jnp.logical_and(jnp.equal(slope, 300.0), jnp.equal(new_position_x, state.player_car.position.x)), jnp.less_equal(jnp.abs(jnp.round(jnp.subtract(new_position_y, slope * new_position_x + b))), player_speed)) jax.debug.print("isOnLine: {}", jnp.subtract(new_position_y, slope * new_position_x + b)) return isOnLine @@ -301,6 +302,22 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: operand=None, ) + player_y = jax.lax.cond( + state.road_reset, + lambda s: 105.0, + lambda s: s, + operand=player_y, + ) + + road_reset = jax.lax.cond( + jnp.equal(player_y, -928), + lambda s: True, + lambda s: False, + operand=None, + ) + + + #jax.debug.print("Player X: {}, Player Y: {}, on road: {}, jumping: {}, speed: {}, road index A: {}, road index B: {}, current road: {}", player_x, player_y, is_on_road, is_jumping, player_speed, road_index_A, road_index_B, current_road) return UpNDownState( score=state.score, @@ -308,6 +325,7 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: jump_cooldown=jump_cooldown, is_jumping=is_jumping, is_on_road=is_on_road, + road_reset=road_reset, player_car=Car( position=EntityPosition( x=player_x, @@ -333,10 +351,11 @@ def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: jump_cooldown=0, is_jumping=False, is_on_road=True, + road_reset=False, player_car=Car( position=EntityPosition( x=30, - y=105, + y= 105, width=self.consts.PLAYER_SIZE[0], height=self.consts.PLAYER_SIZE[1], ), @@ -439,7 +458,7 @@ def __init__(self, consts: UpNDownConstants = None): temp_pointer = self._createBackgroundSprite((1, 1)) # 2. Update asset config to include both walls - asset_config = self._get_asset_config(background, top_block, bottom_block, temp_pointer) + asset_config, road_files = self._get_asset_config(background, top_block, bottom_block, temp_pointer) sprite_path = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/" # 3. Make a single call to the setup function @@ -450,7 +469,7 @@ def __init__(self, consts: UpNDownConstants = None): self.COLOR_TO_ID, self.FLIP_OFFSETS ) = self.jr.load_and_setup_assets(asset_config, sprite_path) - self.road_sizes, self.complete_road_size = self._get_road_sprite_sizes() + self.road_sizes, self.complete_road_size = self._get_road_sprite_sizes(road_files) self.view_height = self.config.game_dimensions[0] # Precompute offsets so repeated road tiles can wrap seamlessly without gaps. road_cycle = max(1, self.complete_road_size) @@ -466,22 +485,26 @@ def _createBackgroundSprite(self, dimensions: Tuple[int, int]) -> jnp.ndarray: sprite = jnp.tile(jnp.array(color, dtype=jnp.uint8), (*shape[:2], 1)) return sprite - def _get_road_sprite_sizes(self) -> list: - """Returns the sizes of the road sprites.""" + def _get_road_sprite_sizes(self, road_files: list[str]) -> list: + """Returns the sizes of the road sprites limited to the configured files.""" + road_dir = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/roads" sizes = [] - complete_size = 0 - for file in os.listdir(f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/background/"): - sprite = jnp.load(f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/background/{file}") + for file in road_files: + sprite_name = os.path.basename(file) + sprite = jnp.load(f"{road_dir}/{sprite_name}") sizes.append(sprite.shape[0]) - if file != "background1.npy": - complete_size += sprite.shape[0] + complete_size = int(sum(sizes)) + jax.debug.print("Complete road size: {}", complete_size) return sizes, complete_size - def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.ndarray, bottomBlockSprite: jnp.ndarray, tempPointer: jnp.ndarray) -> list: - """Returns the declarative manifest of all assets for the game, including both wall sprites.""" - roads = [] - for x in range(13): - roads.append(f"background/background{x+1}.npy") + def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.ndarray, bottomBlockSprite: jnp.ndarray, tempPointer: jnp.ndarray) -> tuple[list, list[str]]: + """Returns the asset manifest and ordered road files.""" + road_dir = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/roads" + road_files = sorted( + file for file in os.listdir(road_dir) + if file.endswith(".npy") + ) + roads = [f"roads/{file}" for file in road_files] return [ {'name': 'background', 'type': 'background', 'data': backgroundSprite}, {'name': 'road', 'type': 'group', 'files': roads}, @@ -489,7 +512,7 @@ def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.n {'name': 'wall_top', 'type': 'procedural', 'data': topBlockSprite}, {'name': 'wall_bottom', 'type': 'procedural', 'data': bottomBlockSprite}, {'name': 'tempPointer', 'type': 'procedural', 'data': tempPointer}, - ] + ], roads @partial(jax.jit, static_argnums=(0,)) def render(self, state): @@ -512,7 +535,7 @@ def render(self, state): tile_offsets = self._road_tile_offsets tile_count = self._num_road_tiles - tiled_y = (y_positions[None, :] + tile_offsets[:, None]).reshape(tile_count * num_segments) + tiled_y = (y_positions[None, :] + tile_offsets[:, None]).reshape(-1) tiled_masks = jnp.tile(road_masks, (tile_count, 1, 1)) tiled_sizes = jnp.tile(sizes, tile_count) diff --git a/src/jaxatari/games/sprites/up_n_down/roads/road1.npy b/src/jaxatari/games/sprites/up_n_down/roads/road1.npy new file mode 100644 index 0000000000000000000000000000000000000000..b75360c543d5430074567626243cb2535c8dd412 GIT binary patch literal 101108 zcmeI4J&zql6NQ(Q#Hqv07Nm7xfeH8lWFUk@KxA29ge8R08c6u~6a3)CSLg+GqnuBv{fXSutt|NZLQFTeiogR4KT{=E6&<(u#CZa%-h`Rn(mH@DX} zKfb>E?e4`d&tKpD@bd8WFJAoo=H=V3zxnCKuP@)~kDh+^?DqQUvrlfXKmGXj`tNK1 ze;;05UEP24&y)B5dG+eieak;;zxeMz4>g~bXLZP>_EUZstDUEs?;3s>t3MZOoZC!$ zhOsc#o7XC1t?N7GUCBMS)&uLibiK9n31j(wj7+wEM#ESb3uF6Zp!GdLMqJ&x9_Bme z9%Bx-C&n{hwD(+bp04ga?=a?_XYcbkC%?-Sn8V9Cv0-dJ#&{38`W#Lw zYg#bI{oS4!&wSDEJ_nqat4Gf>jJT7-y)VXadnFtww%MA2jjQaTNoR2H#z4{W?b#;H6k{Qg|Xh8&Bz#^>)tc! zouB=<+O4m5e%5meW4&u~^D~}3&8FFptKEKk=Vv`(Z2fnod*+PinJUfoN~?Uz4`X30 zj4eMO>phc)p38?mH*>$P>Q}zIALCVi`dsho!#$?fa8IRGKINW3FiD&5Q>Y0z;FN}pT#axx?Ju=3bRcBOP&w5;4-CwO|Ju2SyPhpHXd)K7) z* z=A*AP*SlKfSH9X$tyjLvuY8rCKJ(F6n(JL1#+JWlt7fRqMENSe@>PEN%tv2ou6K19 z3u9qy`SVKmIZ`w1?Ys55@vMLJdbQsv_ZV|*y(d*K>+h{uPixKk-Tc+{YQHeHd{=Vj zRjtmT>a#vKUv<6OFN}q;FjoJ5jL#Wo+FBn!Ps`u`|E&3Nk6qdG(U`mvEIywvH2L|yIE;6q~hswy{k`QjC(508dN-eu5YcmUu)gEUgmS_ZM~lP zln!IL#}u=QA?1g$Fc!x2@AIqYNX@Xf&(G(o+O3!SRo640(uaFY#Z+smu2=Ob-`&4; zysFRj-TIinx}N!z4rA@#l{`ar2CmQiN;AKTr_c4S4r5`AOi{Z`F@BGXaYnA@jHumw zjAy>8uI8_s;*}})XeW{=A(A=F`oIVx|+YbUezDQmhVb6GoFd7Resm! z{-f8c{lZunt1@~CjB!S;=8V*OejcadmCt;8(}#O(_MX%kDxdwgj(791p4oL6TmD-F zE)yzV`Li>m>JMXKOfgdRXJky>3(Dur7*C&C#nX2+_iL?3*K-PEqia$#b?ad~^DEzt zXFgYRzt(zmJz?zeJ!h`YRK?SGwc5YxkFF<-g)zlkmFXE6Q!`LLXT*5=4{3gmcGS#w z{QB{4`mR3aY4!K${SS<(+DF&KdffdPukz7%wYwkVm3He<@oGQ%DxdPh*dzIjsTry> zq3>$7zw3|QKa7Pj#hlA9bZ*{*@7~>wC3H_>OephpUuBYEx ztNrM^dUU`%qXAnN;BWuI*f&} z-VDyb7-zuen0_yP?&oT19s6NCeOEKzXnkO;_nz)d7|;5Y&-krt)en7FuibwbYyYz( z_e{wQ*VA`3^NrSFER6MLZU)Bq-1eTy+WFb9t64wS)9jPe<;H1oCAVJwWbW{Uatz!;xX_P95#`lavc*7F*{ z7;9$F%4fXN%(pkq{@EXWS2JH#hq3LyD|yCbM){0an)#|ajD;~WS7nOp_sCf9nQ?|| zYt=7(S6BO6gE7|3o|VsdrI~MUn*Fms`mSa^SBJ6r-jzIKGNXLPE6sea4r5^~j4c^Z zGf}?#9L^Tc{#ieLS2Lf|Te!zqBYRdp!a^#=A#Z{Va#QU@nnj62F6&k>XE*yx!>M2=fVEyyPEmJnEr1H&TMu~ zX4nsXS2JH23u9!$Ws3347skREnIJQ+X1*C2V-K8x@)@r*^X*NufA&Y;)nROTk8y^a z`QEhZm%gjRSQsNSe1_<|I*f%eG84x9pJ8)`uI9`yO}q1A{2uNxb>{5Z)hhqqeRp1r z4`X>(s%M1ExH^o5u`s5;|8k!zH(s5=?7lk>H+~QInA@`(KRZLJes>;jd>G5S(q+bt z4`X4>WyX!4fibryH(tFI~J4d;6;Xs;}x-`M2=>n5wbrS7)TwukBaoqt>hb)cP>?Ot`uliH@)%q}|??E-g-o85Hs;}x- z`PKR`7RD4Km%v!>8L0fdeRYOaU)8VjtM$rnagV7QtG?<%QnjEdgX_)Fc!w<|CzQr!>X_D5tU!9ullFlW2(lguX<4V)%vx4 zRe#l2{i*zFebo8Kt9}>@V^xOK`WYBgJ#5`q_lWDWU$vg`qxBk$ zvB#~?i1oT!_2>HBU#(~SXdTApdsmW?Ft+^cVYV5uURNsyT%Y@^^^6~_bB}HR9wQTE zivHI02*%j+Y%^m0u4cck=6+mHe``97pvlww%Y&z literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/roads/road2.npy b/src/jaxatari/games/sprites/up_n_down/roads/road2.npy new file mode 100644 index 0000000000000000000000000000000000000000..df134e6659990cac0bee8be16975e389cecc5d25 GIT binary patch literal 59492 zcmeI4v5uTI6orSDhF6$wh_oppKq!RneeayFneWo{IX&AS!+!01yRYe&)?sYg`6|61r}B{rS2HKa z(=V;VSQsnK6!VS1SnG4g9<|n7ud7)v$J4K^OEAV7)tczLn)5A9yZvB1`=NZsm)2pd zzqgW1xSEV`JpI}_jD@is3uCpJTLNRfom;hMtk>1d&++u<)-4$08dXo| zyPEU$rrq^0p6gLQ<6G-6*8Z&|6RsvB98Z649mc{~Z3dUfSa0W=^|+c0a6J9q^c;+_ z2dZcEUCsGgYqws;vtH#her_Gc>c5p_!qsGiQ@`i81th*I}&uTS+EdO=dWre(ySrg|XUK#@IvEbNY_vT<2(KzEiQBlQHy1(_w7-TS+EdO=dWr z{%AUkg|XVqj=&gupw2mcS93mUHy`6wKKi5UFn0X&y4q7R<7(#Tc>1I3Fc!vYGaSak zSQyhkldHFeerwG=m6fDTc_Gn`lU7d$^7(}uESXQx01|+vD{;1V5$yd zhd(PU@g7rqJe47LJnOkM9mdMPm1OYJnBscFSQsM{rF9q!W2Kp5zAzTX$V6!!#==-> zrkF2`g)uTwT8FVPR+=g1TOwop-GQswBWgDvJMXKOffRmACWP2FDRcqV?2Fo6;I#QoNsg;#vUL2|8#0k$qeJ^yPEk%*I_J- z)n+)1g|RSp%)r#QhvTRASmmSd>Z$e3J${6H%-s{lv&YJ3ywWNkeOIgbd-ucG@$W9& zJtH$5Pv6zdH<}J(VXQW@OJs~auWtejOl;Z;qD2U;duJ4X1?Ba7z<;y8C?Qn>=Ew;`mMF9hrX*@ug9IQ2gclm7N8RauxY3A!qhp{kLo5>LwV^4U_>G!5pz4Tq(+j{3< z%Jts5DXS~wPH@6OBVXQWTBQVAu@Ep@0O{?|L zclF%;Xu+7<6UMVg%4dA3UxU90udcXjLjgt7VGO71zCQ9k39X1>}wjD@i<)*l14 z_k{U*FDak#N;6++UE&^7YfwJd#CZCbrd2=aySjEi!&v|KB$*)-^j*z-rF9q!W2Kp5 zzA)B5W85>|Q_5$&(#$tipTn4{QTeQy@${n)&w=U}eOH(6cNja|zta7!k&n!94*ITU zzNtElg)uTXl_`!7W8*Q#JyeX)cQxm8^~W&g*2H+$qS`s(S6?_R$8;n{Hi^QS+*e)jhM>z|(f^6V{s@6kt}KD>VP$%hZGKYst= z^zt--_!2^T)+gy|6^w68>VHZa)}Ly-UOFhl@|y zulOtH*OLB=zv8b;_ZQy>OZMX&w7Fjwt3CAAYppls{OVngtH=1(ezw4Izb;li=&jdU zFaA2e_oF+1#<%7@aEw=2>9zRyEB=bV+CM*`_57@XQ}b24yGNx*-qjz+pW^%)uRmOK zKQ5-;+OWG{#=HAvzQV5lI6nS5{rgePSMlR{PPtZkD#dFp?mz=?0=Tgg|I^20M{*dmm z?z0{jAFIwK)``F7ok`9tTt6}1S?kRZx>UqH+Jh`ysC%s3R6$T zlV`qzr}&HODa`d%JbC65Gv3A2>y6!d7_aJKyu#E|@#LBB;P|Wm`;qf;v6^31ubQvw z8}sp3{H6Rh_KSPHMt^aSsy~XS-XVPGb2O!|_(Qsf?khd=)FH1h^Bp|+OU0}8$y1NK z!pv7#=_y|6D_-dSglW<>yuZQ`3|1qFSR~- zu1{WJ=J&>G{VLzPUh28^_0FeW3&&sm-;b?+_u%5MyML#u+K(PH8Ed%pxbfq)mbk9g zQ_Xkwl=G|q`NhxQdSllgT)&!6y%rw))#3-PT2D1!)#vhVJ#M^Pj~idDr<(878-L~e zQhsy%#Mt;N{))d&`}0y%{}s=s+>fhQ z&2Oz2fAxPqE;+A#v#hRH&EHqk?q2*Af5l(kufNzY7w?`ue&2nr=;|?kr*jG4hx2P| z`{X=btoq^VF<$9!iI2bTzaQ0n=@%C({VnnFSNyf4pSR#I_xxZy`%9kKjb}bF^%Q2l z!qijo-M>{OQ(a( z)wS>b;IDo6$M>rBRP)ER4A-t6tVdYz4NKp!ts~(3;V};DxUEQQ;%5Xlkc5Ry%t`Bzt}GqbN|dI?_#A#-o>gO>Zy40 z)qLuW@f3fJ{NJ#IeZUEI5VSI@1N@m!DCjqjaLy)llz`u}{)aes$v&fmqjC zg{%23{rD^Xiofpv88mw`#(bW}+;B~PD&Bn#roMAM#*=q3_0Aa&=hwN_=k}ZN+>hcJ zU&U%Y@-D8{cg+0w>+a7o=+S%ng_`7DOuci)@mKt{q(8UdFZPG$0r_KNRS$U=yZc#M ze_YQLe=S{)tIzoHeuitsGk$5T>LKr9xBjK|$MwWt{ojw{{tnlir{Wo}F!dBxdgPT} zi;utJFZzkN#^~hIn{1ty)y1$m(zw>O4KMPy(n)~HiHb=B|7+0Z`NW#ZyO?@CagX!M-4ElrAH_3% z>sa-Jyo;CYU;K4`??=v?eo;K*6{cQK9Dl`MJ^e(j_-neqoX^UZjOBU6waB}edKEnQ z%guM|WjxnYJmW7NtA3GpanJt7U#I``A$Oki3-igln0gf)f5l()U&T+%jlb6CFV0Z; zgS?Bmo`a|O%dLm;tVi*T4{kpLs-NUtTN4kz4f>^D8ySU-1|HL>zx@ z&m)zYGWW@otJ}&K^&C9q{HoM8G_Yp;xGC!=hyZ;QJE?I;bN}MeDV&S za(=1x$h+97H?jAhEw|r{XHON+_!f@8?$4{$uXE(+2N!cs%qQ>Q_$&URU*fOzdxA4{ zFz4@LuEl)v6+GqqqQ0t!yo75g%@E7*=QF?fP}y=K#B_Ucb`sT|M~t^6cuP2bX6bzJGc4 z`Ov|G`)Pyz%qm#k>2;NBOTk{_^F;zVaEpKlfp*(x1cH(R)-q|JM5##@;-` zt9%{ShcW%G+#k@1-2Dn;{pTd_Txl|)@>QIEQ|mAm#(Fb10%Lqmd-tTZzdAqT)I1O4 zuBKmaeF|edld93pcm3$UG|hT=KE|ote6FXaUvC}8rhixRPRRhZo6q&s^y{s|SQraq z$NyGD@1C^wS964MYMzI2SJSVz&K%1e3uBpMWWb$U%%@*(ojH~{7RHX}*xb*hdnf9= zj4RFkRlbVTueF|HjM0oG8Cj!>Ghb=?&8>M}o`-R2m9OITQ<{EN9mbabuH+q)5fx{?()6q9Fc!whT$L%V z4`X30jOpJK@7*Ilm#$XlW1O1%xpC$z&Hbu6b1ZX=%vG7<`Y;y8!kGWA?A#;M(WuBWD-tCujxR85S#+U-f}e4dy7j8pf{=YG`O-_>Dk{CA~df^k>7Off%< zg)zk(nQ?U(yFX(O|Ndit%7=GI55}o`=W{=5?(gaYV``1slW}VLGfwU1b3HZvrq=2_ zj8pTxj8nV$Tu)6uSBJ6FW|Pxrb!UuIb1%lJ-F&X6rk|_BSQsNy)GkxZAAvERnLThd z{d#MjkNYuB?dG>$ul#sE>LoI!YGT~g>_KnM^Kn1Msonh6>y;nRM;*r2_pVe-Fz#wH zq;wbyV`1##=+{=Jb@Z&8&%WSJSVm)&BIa=C}4!=V6@thq2?oBg`IE zHSd6${)|(rd=+>7n6EVV4`X4BOels_oPJ>}jFAb&kc!hUjD;~Wp%_we`h~GDMkW+P zDo(#J7RJbgVo1g57skREnNSR=IQ_!dF=Ol@d**8TQPYocYL&0zt{?MV9mc|#%VZey z|Gc~E8SlW=u75S3=d1du^(vk@mO17!shDDX1jg6{SF>hUyZfyz-|ZLk@59*WJ!8$T z9=-nR`Q3goKaBN%S61(M?)7BG)nP1*DTZC9n7>BGSfi^|&(|J*^k}!QeAu5gGFIu& zVJ&;KHQl;@VeI&s-5OJ@H;n0jpQ3e-cn53Ot^2ozIo7%-YiEe{_3mF7%lEG^mUkoP zPHP>;!dPpj=oiM08SC8>-oe`R)cqUr``FYy3S)PFeaG!b7(1SmZVyMxpZXaMV`0o? z&V8mvV9f0S^WC1TEuZ~xb+sR1Z2fc0Ws3P>?C$qb_UE9hJsN%efAz%qv1Xf z_QTb!>s`YfYu%I4{9J~Z-@D!=Fvgl)-MdFi_jmin{NDA2vGKo;kr`Kqv3&nh47f}& zU-?(#BhImE&7-d;Bd(sh{y7+Pd&2ywd$hFu*bi4vT~8Pr@B5g`6!XJazJDnO$c(F% zf9v=NbF6jEqxrcEF@NrQr(ldVyL#>(&Aq?dFXqpE{xG)u?_*@f)nP2(zZ3&5Q_NTX zz2lix2|bxe(a~Kd-p4h zjXx({rkEeb#%G}S9g-PW_dZ|d*!mcD8DoBH=1yUZXL5Dx8sc~NIdSiU`K|jI#_rxv l>;BH+M`m2z`pjW0jFG`Gwmv3W-zoi`K7D6j`OweR{{nl1_HO_H literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/roads/road5.npy b/src/jaxatari/games/sprites/up_n_down/roads/road5.npy new file mode 100644 index 0000000000000000000000000000000000000000..b390224b85f5ce62d91cae78270c69d5d3ae9132 GIT binary patch literal 72344 zcmeI5O>SL96ok!`Jy?1(Lu3XPAR!iz8Hfo&jF3f)7$Ab|zyz3P1MI*n@D5Gh?B%zUp0THN1rWQ-Y5^PbQjU334nwVI#T8<8>Yv9@+I zQuU+xwew;>H81PK*!DSr8L3Q1^SS@pTFuMrg|RS3CX^0i+h_PrH74N~`S~1}IoJXy9 z>%&+WbD44LXJkyxNMC99b*uW-_SODepZnu_`bu*>bq~gP%~H*o*^j=`Tu-gmxAwW8 zQqBFcAAO~{o;r+;|6a*^R;tMi`_WgL>#4(77$XykDf(PbJp*HECi<>sCe&`f+V$Lz ztGOSpr|)X^qYh*1pJQZ(4A6Hq`%#CnFy=DFdNM^F#=;nxaW$Euo`JFZ`xlR=A8rP- z*PqT_)*jOQ3LZJkl?{+y4R`=Re@_N%RXFvgiV1ASMs-|Sk=&-qln^4YJp z4rAj#kC72FK;PBuS6hd%Fjku>_FI85zHT|=Y?|{Z&G}Tl^4YIc@4*;n*9WAiI<+EOC_H%U@3u9qy z%Ye#6`R?nmcYW#pxt{x{?`rl_dW3U~d*sZ@XT8$wH<~WpKi6|U`mSa_rNh|voK%^$ z_SHRC*XMfH(|0xdDILbbSQy)W&y}wmHzU@w-`<-0qjvYF*3);j+RupR828MXmCt&m z*>7)Md%x_*`RKcv{iws(ct5+X?lJG7((K22`mSa_>M$0@T&7r0rl@CNjC)oy(swnl z*ITRomhOlB*^j=f*)NRg-%~KN)>_@u((AK7`_Xqb`-QPEMkZXQSkHc8ER2x}GUICY zTY)iVz!~Ygn*Ca9wZGEqyZu$Y^26Bn9Ajps_tag_>r=b^SkHdcZavpihp{k5Cc@am zd;e}H^Ev9;8L?k${nSjSUn{--X|Ja~XaA>aoUg5Tj&c6d%(%X_=KiTmug`u;^LnbD zei+MpC7GZOV?HCyP-$k|TDRVh>QA5Z(pNg?n9nFPudUTRE4`lTPoLMLuk?(Jab{Pm zO!$8oxcyhJXFf`|U~Ke^RVMUDpQrc!n2*w7?0mZaPrNr%yq;o;ei#d5WJqZ;r1W!O zOl6>amHFzv%18NOEa#YdjVOkcAI8F17%Tt#kLq=%e02|2_fO18 z?0oZcvnsRIe8rIRRUTn1j46hcAI7%NTg*&lw7SncT&?n%@f>qAV7o;`_aDZ>7#UJJjOF=DF`yVyzUohZhI35KtbES=+_cI^`K5V3@C<_ulg%L=UC3MF!mt--VifVudCI4=Hcqv`HXmuxtXxOc1E-HV;-)q zJ#QFW@4eDxiuGYE&-jV~GUIC1-}P5G$K07$t9Kb*q4W`n~7v!5C+D_1+oneSJ4C*6+RlFgE`4 z7@2W(7|ZjQV!&mJ^{RhsKj+wZ=NK7rb!%qa{w)}DGhw|u!)o=+!_}kZ5yoo;m z{rC2Bj`css$cU@=X2!k#|6$C{g!S$oSF2|puAVKAFqU(yHM3zXjD@lO4AjoteJ^Bv z?fvxb_g~I2&g|;mGmUmVH!s$YHjlfXFxLK@Br~pd@9AjuVJwWbW-^SmXYBs|#be1w JYX(~P{}(_T7vul{ literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/roads/road6.npy b/src/jaxatari/games/sprites/up_n_down/roads/road6.npy new file mode 100644 index 0000000000000000000000000000000000000000..2e16b2319694e65dfc08f7f616c733d3494c8e58 GIT binary patch literal 90704 zcmeI5&5B(`6ouQV6Q@oBJ!=igATa@P=)gfh5kbUJ2qt1c66vTDe?XADfKx|#g1q51 zS7=GTYqGNLt~%#bWeJyL-L?0w+Pl6}x9$A>_1Tx7e*X1aCqJM3a(4ILqi@d7J~}=7 z^`~oRS5D8qeR%%k`JD&1AD-X6w}1bWJKsOL_w4;g-`)A)-ZTB)wHqH^IlXrCgDa;u z-oJAC+o}KG+b1U{yDxsf+Vyq+{yV$MtM~r4UjNpuTf53@?f#ZvtTLZ{?%Fe2d;QVp z7smSD;Vt+KWBQCnpHT}=Dm6x*U+yvej9l<>^Y**DHC|Q5*4Xp&Ywj6)UibL<=rigu z6Q}0r^9y6^+)ZJu?c7Iy4jmqP#Ew3{++%H@r!aOn&oFj4lOCK@YUUo(XEgeZDtxra z9esYe$MiGOf|E*(Fs9FF^ch|7aq)e!rGKo>Z|VJVz5PAL_gd@GXViKgZoSdx7sg)y z`}{5&t>VL&KBLiR)QXj3>x@3X+++F~IhKu9d0|YS(daX3#mcdDMxS5qG5w4j%SNlb zFs9FF^cl5c<=8r-&oB3wenu|XxcENV(mz%iTYA4-Z_gOtYpqA0QR{iQ^+umx7}LK_ z;WE)WK8)!z8hu8snQ`llKEDy}G0xaMquJu0Jh{B9yq2EF5{$8@%6#^@toNqr(dHM% z^k@Cn7!G6l%to707}IC^su_)VkCD+=_1DsS@b|~NHD2tOd!X-X)@!ZzjJ2-GdhA*G zj8~fVX4CAS{n2+d>$TQlY@NHwIY%7qNz1BL6g|XI5vEB@f@wsJ>qigm@&Hm`S zn)RxB4aV3rd!X-X)*D@`{@I_3S3c`ibr>7(vyzOE0s5|Hy{ZmlVT=q`nPUEojP;t? z)vuoJ)_AcW?!omLuQb|$BkzX%y)gpE6w#r)9jD?cYVey&GlRz#@4x;oO2=r z%yoUnE6w#>9mc{~7&~M@%|!X`bGUXq`{91*Q@ioZr)E8+w|I}Kp6I(;&17pn`{VxT zQ@ioZr)E8+!`R_Hsb;pcug;i0HTO&3)vTv<7z<-z?C|%l^SN{^<2&M!r0;eQ*cIWYo0Op zqkK1>^<2&M!dMt16E0JXXT2~M#>fPjaW(6OvDatp>L2%aYrJ^I+=KGnc-C_@*V{A3 z_gZW2k@c1D#WoZ)gyhSA3wgdd;Z1sme$YjW&d-ouYAT+t9<&dUV<_1$?cKx)NXyoE6w$m z*4!W0S3cvZRX%-Jhp`v$tpA%I&tz%MJ#l^IGoD)I(|2_k3u9zPG01qOXJD-LOgOXu z>aXA{uYd37#Tm|~$%Lyp59ZSkV_}SpC{2bKPd|)> zF*2ew8Dc#B5g20+u4d27r$3u!zpmyym`|TNjOFVuGDA&fl+XAu7RJa7HJMR9uT*VUW{^XV%c#d{v^jD;~WNM@AJ zc8v9mdA{tRxezCNs>ZU)5nOjD@j${Y3^`&3lOX^jqsC?lJb{YWB>0`m<^F>uSz} z`Se@sFt+{AN;2VUGQ)iOt#ud+W38EDy%88&`nltbme%ao)!Z-h>G#$>7-LVaX3xy0 zKbvO1uI4SfsgZ1e5)?sY@&q^{-)ntbC==auPER2P*^%-b=2E13Qn)_ot`b+C$?lJaI z)$Ez|=#Q@1UsZD+tVe%o9mbZwC&@rnlNr{dzqAfxVXQZEVQhKE)W4ru^?Nhp)<1@^ z>I_uA+rwz_s{g9LwRwcG++)3&3}ay|j4jVV@0qiH^|MjsbG@bYG4~kvsCumWTyHjA zosY_2dtPB|{d-a|QuV{w<(-=0+B53CzUsf~_ujv%Kf*nxW~_SLn$LdJ{utj{tM#k? zF^sWCwFkzx)@uE&`RrfqkMXUwTEFUtvE{!WBO~ez8Q)r~^{aju3u9G=RQ?Q%spneF zU}<0Nm%iE`<6CRBzVcfzrg~C)Q2A9~t-rMI_M`Ua)^p?4`pOSu>wi|NGgoJ*@~ggD zU-@Ayj41{cL&|>>80$SlHDmhje%yF>JyoAR>nq)ZF|NryQZt^K@vSx2>#f!PxxO3k z=BxVjSzqZec6srCqn>9*%`>A<&3J0YS2gP^9mc|#%ba40{s@e5&(!XWR6Ogq)~(mC z)>HY)SMyndF|~(k4=UgF)qbnK+K*emT2JLGU+p)H9sb@THM6Q;Ju|Lft*7#p&;3w` zu`s3>Qa+iY9)YoHk1AjJoLN*YV zQq`kAn)PPWoCo`(@9HqN{AVT4aBWRyxF7ni4r5`A%!D!hdkQ%-S8wg_A>^6&*6hEk zXSl~W^Xg2P&-$}#_FvUu?9xunaCWBHpPGm3hp{jg#`M=;?sMhF^9*Lw?mXQ18SXK+ zXE%N}rr3`=4>vxH<+IXd#*GhSVa#R5jURzAx2I}6KTAf})p@x2oL3k-yvI1B(Y4D+ dHJf$GgMR{o&8wE)N&` z!%wg7e!aVS`NONbXSWZZ|MKR={q392-~W8`%k3Ne;pHb^TzdFTZU9Wd+(na zTXw${4;8){TmK$ca1+UC#>`lRmvdr!GsPOuxhpwlv^B84t?SYD$BeD-95rM9KC}9q zHDhM1HDj%rV!mv>=iHTQPWn-uUBlMLNBhlq_D}hYSK5rt|MeJ|h-xyz@${qGjF~Yq zW5(uZf_oa(WFV?JAIHT}L9&Y^0eAJv@C`t9c`+E2!_pUP)^RGYE4_b=MN9`R%% zs#!C~)6Z%%X2#6e)yaFGvd`7%G>p?@BU|)ntVE==auU%#2ktI0Iws0iWBbX1>-s zS`Xt@J<)iTuiEdy7}pf-8RMgx^S9Q~dKjpZ)xH^Pe^-)$s8)NMB_PrG;PMrST&O~FvcG6Iafa8m1aI_HD38DAN|?&h>WS8at|3#KdPCJI+~C1 zDj)sXwHZ5opRL+cG85Iz&++tU*JjL&RWm#zW9*4K2k0-YqxCX=c5TK^&#~4Fk`d;o zzcy{g%vd$kX3UJ4vH0(-)m=otwdS5NKmE1qHJoGIBekdWTkB{&j9;5JW9{!sGD3#w zuU(rlGgi&`8X05Hc#bTsSRZ|#r2voGe#y_YcpoXS~JCbTVRZ9=3UlW zbH1%>GuHkb+d5NRzZo-QWTLe;V`i*1Q_Qyo##(>970t(`kXvuk9GJ>$7SzqMBL(O0^4eM{#vWBPaQ(LEtE98X_qj-O4NF*8=ptQj+7W=#Kl zj6LGrLw{+l>Y=an()DNO8}S@dJ*LkdD4+4wTGc~e>FRo0=Qm^R?@I0wnNdFDsX2ah zZN|)4HKQ{!#-7afoU3~2E6sk4rbl2*^^iV$sC>rv)~a6mO7~umn%|7o-<8}0GNXLP zQ*-=i+Kid8Y9?o3j6Im`IcELTYCZIoX8zu~2V<&7^w}fjGk$cf) z0GUxfD#DteU|Y7#saLZM{caZ&tUiFPd)&#KsgADb~VR?VFmGh=4# WN%{A^)jiJU&)ky_W3%&T$NvYWrSRAQ literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/roads/road8.npy b/src/jaxatari/games/sprites/up_n_down/roads/road8.npy new file mode 100644 index 0000000000000000000000000000000000000000..6d7c053469dbc04f885297dd3f6daefa48b6e2ab GIT binary patch literal 76016 zcmeI5J#Q697=>@zG&CvFAl>c~$tfJ5iKr+jEG3E}rAUP=Fp8xJA{%J(QSt}U=D*&s zg(Y}a?CZmP?9A*zTEX%AG5el*cGg@F{QKiCKm7FT7rVE+KM&6@uAf~UzTY4IdUJX> z*&klKx_W(e{_^RotLGQf*PopKetq%&>({@X|8enN-#h*8!O8yg;r)~SgKtjufA94l zU+#9h9rwC5|w{+hAa^R0Q5gR2jFY|THPWAQWf zAs;1vJvwur!x()Y-6PJMTaWIa8S^=o`# zP517f8H>;5wV5_!@gDW=kr|8kc*&z4&M`8&G*8Lc!>3;#Te#?ls%y_MF;;s`)f48| zzTei?qaUiSJ&ze{<7`s)U1f^-%33*c&6pX>!O>C=GiJt?_{b46V`eM|M@v1-m>FB* zBS*}PnXw!kE%h*Cw`c6DN4sMS7k{6q>e@3*jMZLK^@RDg@3*z}=!dFn&tt~wKVzwT ztTM&?)b-VRGiJtAX4HH$X2w)z)O<5$##Cn1d^2XoRA$tCGiJtAX4HH$c6-K_K4+{M z;%{Mz^1+IsXu)wSm_W4HgDzxLjIYCV-H=6|Yhsjr(cGggYN zo4#hOKgQm@yL(KzZuY*`&ucK2;?GoD%bQl$>CHbg=5ws{vt-81m>G+|_bBbrP5+nw zJwCR0rM^2$&orH5rPql1x9(AN|7*uv_s@*Ac{VBBnz48fTlYxW?=^B;_iu!Aj2^GC z&n>M^oTn5 zbJ<%z`l0I7d0RZkIFsrL^Vx4}ntrG{)h{zP-n)`as7x`R{Zwtn%ov%XR+(ac4~+4i z>4B=*Z);lhi}|WwW^DX(j7*RjRkL5D&6pW8V{-;nPnaJ)ccaJYhpJV7dOXKCqv{d! z*>7u_eyBR~%Z%M!{qN63dsI5EGR6Gp{>+#eBQvTtWAUFwMIMbFSD9jdHT^PUW{gazOfjGR%$OM?6J$o!>^CA~oKf|N`Runf zO+Qp^#^&c3J=_{ooR@y6+KiboGDWR2#e6en#>f;knNhVFGh<|mT4jp)X3UI{DQYsK zYBOfW$P~586!XoP86#8FWJc8^GNyXKe0sb!O+Qpkzs%VD9HU2D(<)QUr(b5wjFAa4 zqiXguV`hv@s7x`R{YGGnGpd@Nu%7YJH0M<{{b0QroBv&FdPEOL)0|h;^n>+g%#4u< zRhzN-_gnNZ@`&-#H2qLD{a}5M=NM;Db>s=-qiNMI=F^WTZpI$Z{(mcaIGQFCs-_>T zXWWdLF)|WqGQ@nw&6pV@BatRU%xAm@#yEqjIWy}SA5C*!RnrgFGfr*B=AT{Aquw={ zP&MagJ>%47%#4v4YBCeWnQz9-7@46aGf|xRBQnMrRZm#oo94W#=De(DJkl*P#+g*j znOV>HXqxk?ntrgJ@kpDo`M(7kdDeQIOsJakv7Yfrn=vyMF%-qOz!+yt^@ROe*PK_? zoR9U4tGb1Aj5Dd4GqaxY(KP2(HT_^c8gsRC5>ls(I88c&K#*EEBGtKP@ zJ#S5OUR84*)-#@}*Eq*Gld3s0>lq(Sb6!=`57slDs?At`?@BVEYBIxm##6N!Gh=4V z-+z$-Rr4HTJ>$7`jdP4MshTshp7GvvG!NsdUfT~dw*I@4%#aDjRn302wHY&GrJ3u2 zF+R7YJz2YdG%w?-u03A~#?+aaPY`lI43ypz_UE}2#hD*z_RFnvoMX`pQJga|pYg5f$PdO^NQOl6AoW~@KPxQ7uVjH{aKsrojIsrSTu-cuB3zV-5R5c$ct zs#ETivDGQ%|)S2g>o+Kid8h&h!h=9{tp8RMQt%&2kpS9J}>IFovh&exs; v`kks%{WoLU{-yd|qdt`>)~DXH88c&KirS2g$4Kg)vi|b&&N1a;KUMz^`VKbE literal 0 HcmV?d00001 From 8ddc028f488b6d04fa481c7337d37349b4b2dfc6 Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Wed, 3 Dec 2025 00:00:17 +0100 Subject: [PATCH 12/76] remove offset, add moving backwards --- src/jaxatari/games/jax_upndown.py | 24 ++++-------------------- 1 file changed, 4 insertions(+), 20 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 9e30ea7cd..d86c3cc00 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -22,7 +22,7 @@ class UpNDownConstants(NamedTuple): FIRST_ROAD_LENGTH: int = 4 SECOND_ROAD_LENGTH: int = 4 FIRST_TRACK_CORNERS_X: chex.Array = jnp.array([30, 75, 128, 75, 21, 75, 131, 111, 150, 95, 150, 115, 150, 108, 150, 115, 115, 75, 18, 67, 38, 38, 20, 64, 30]) #get actual values - FIRST_TRACK_CORNERS_Y: chex.Array = jnp.array([105, 65, 7, -50, -98, -163, -222, -242, -277, -362, -420, -460, -492, -520, -565, -600, -633, -683, -733, -793, -820, -845, -867, -895, -928]) #get actual values + FIRST_TRACK_CORNERS_Y: chex.Array = jnp.array([0, -40, -98, -155, -203, -268, -327, -347, -382, -467, -525, -565, -597, -625, -670, -705, -738, -788, -838, -898, -925, -950, -972, -1000, -1033]) #get actual values SECOND_TRACK_CORNERS_X: chex.Array = FIRST_TRACK_CORNERS_X#jnp.array([20, 50]) #get actual values SECOND_TRACK_CORNERS_Y: chex.Array = FIRST_TRACK_CORNERS_Y#jnp.array([20, 50, ]) #get actual values PLAYER_SIZE: Tuple[int, int] = (4, 16) @@ -54,7 +54,6 @@ class UpNDownState(NamedTuple): is_on_road: chex.Array player_car: Car step_counter: chex.Array - road_reset: chex.Array @@ -302,19 +301,6 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: operand=None, ) - player_y = jax.lax.cond( - state.road_reset, - lambda s: 105.0, - lambda s: s, - operand=player_y, - ) - - road_reset = jax.lax.cond( - jnp.equal(player_y, -928), - lambda s: True, - lambda s: False, - operand=None, - ) @@ -325,11 +311,10 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: jump_cooldown=jump_cooldown, is_jumping=is_jumping, is_on_road=is_on_road, - road_reset=road_reset, player_car=Car( position=EntityPosition( x=player_x, - y=player_y, + y=-((player_y * -1) % 1036), width=state.player_car.position.width, height=state.player_car.position.height, ), @@ -351,11 +336,10 @@ def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: jump_cooldown=0, is_jumping=False, is_on_road=True, - road_reset=False, player_car=Car( position=EntityPosition( x=30, - y= 105, + y= 0, width=self.consts.PLAYER_SIZE[0], height=self.consts.PLAYER_SIZE[1], ), @@ -517,7 +501,7 @@ def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.n @partial(jax.jit, static_argnums=(0,)) def render(self, state): raster = self.jr.create_object_raster(self.BACKGROUND) - road_diff = (-state.player_car.position.y + 105) % self.complete_road_size + road_diff = (-state.player_car.position.y) % self.complete_road_size # Vectorized road rendering: compute all Y offsets, stamp via vmap, fold overlays. road_masks = self.SHAPE_MASKS["road"] # shape: (N, H, W) From 1c1ce94fca74a48ec7d4996f3e9f7dfc3453c4d2 Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Fri, 5 Dec 2025 21:45:35 +0100 Subject: [PATCH 13/76] add second road --- src/jaxatari/games/jax_upndown.py | 61 +++++++++++++++---------------- 1 file changed, 30 insertions(+), 31 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index d86c3cc00..a75917186 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -21,12 +21,11 @@ class UpNDownConstants(NamedTuple): LANDING_ZONE: int = 15 FIRST_ROAD_LENGTH: int = 4 SECOND_ROAD_LENGTH: int = 4 - FIRST_TRACK_CORNERS_X: chex.Array = jnp.array([30, 75, 128, 75, 21, 75, 131, 111, 150, 95, 150, 115, 150, 108, 150, 115, 115, 75, 18, 67, 38, 38, 20, 64, 30]) #get actual values - FIRST_TRACK_CORNERS_Y: chex.Array = jnp.array([0, -40, -98, -155, -203, -268, -327, -347, -382, -467, -525, -565, -597, -625, -670, -705, -738, -788, -838, -898, -925, -950, -972, -1000, -1033]) #get actual values - SECOND_TRACK_CORNERS_X: chex.Array = FIRST_TRACK_CORNERS_X#jnp.array([20, 50]) #get actual values - SECOND_TRACK_CORNERS_Y: chex.Array = FIRST_TRACK_CORNERS_Y#jnp.array([20, 50, ]) #get actual values + FIRST_TRACK_CORNERS_X: chex.Array = jnp.array([30, 75, 128, 75, 21, 75, 131, 111, 150, 95, 150, 115, 150, 108, 150, 115, 115, 75, 18, 38, 67, 38, 38, 20, 64, 30]) + TRACK_CORNERS_Y: chex.Array = jnp.array([0, -40, -98, -155, -203, -268, -327, -347, -382, -467, -525, -565, -597, -625, -670, -705, -738, -788, -838, -862, -898, -925, -950, -972, -1000, -1036]) + SECOND_TRACK_CORNERS_X: chex.Array = jnp.array([115, 75, 20, 75, 133, 75, 22, 37, 63, 27, 66, 30, 63, 24, 60, 38, 38, 75, 131, 111, 150, 118, 118, 98, 150, 115]) PLAYER_SIZE: Tuple[int, int] = (4, 16) - INITIAL_ROAD_POS_Y: int = 25 + INITIAL_ROAD_POS_Y: int = 25 @@ -93,8 +92,8 @@ def __init__(self, consts: UpNDownConstants = None, reward_funcs: list[callable] def _getSlopeAndB(self, state: UpNDownState) -> chex.Array: trackx, tracky, roadIndex = jax.lax.cond( state.player_car.current_road == 0, - lambda s: (self.consts.FIRST_TRACK_CORNERS_X, self.consts.FIRST_TRACK_CORNERS_Y, state.player_car.road_index_A), - lambda s: (self.consts.SECOND_TRACK_CORNERS_X, self.consts.SECOND_TRACK_CORNERS_Y, state.player_car.road_index_B), + lambda s: (self.consts.FIRST_TRACK_CORNERS_X, self.consts.TRACK_CORNERS_Y, state.player_car.road_index_A), + lambda s: (self.consts.SECOND_TRACK_CORNERS_X, self.consts.TRACK_CORNERS_Y, state.player_car.road_index_B), operand=None,) slope = jax.lax.cond( trackx[roadIndex+1] - trackx[roadIndex] != 0, @@ -106,18 +105,20 @@ def _getSlopeAndB(self, state: UpNDownState) -> chex.Array: return slope, b @partial(jax.jit, static_argnums=(0,)) - def _isOnLine(self, state: UpNDownState, new_position_x: chex.Array, new_position_y: chex.Array, player_speed: chex.Array) -> chex.Array: + def _isOnLine(self, state: UpNDownState, player_speed: chex.Array, turn: chex.Array) -> chex.Array: slope, b = self._getSlopeAndB(state) - jax.debug.print("slope: {}, b: {}", slope, b) - isOnLine = jnp.logical_or(jnp.logical_and(jnp.equal(slope, 300.0), jnp.equal(new_position_x, state.player_car.position.x)), jnp.less_equal(jnp.abs(jnp.round(jnp.subtract(new_position_y, slope * new_position_x + b))), player_speed)) - - jax.debug.print("isOnLine: {}", jnp.subtract(new_position_y, slope * new_position_x + b)) - return isOnLine + x_step = abs(jnp.subtract(state.player_car.position.y, slope * (state.player_car.position.x) + b)) + y_step = abs(jnp.subtract(state.player_car.position.y - player_speed, slope * state.player_car.position.x + b)) + prefer_y = jnp.less_equal(y_step, x_step) + return jnp.logical_or( + jnp.logical_and(turn == 1, prefer_y), + jnp.logical_and(turn == 2, jnp.logical_not(prefer_y)), + ) @partial(jax.jit, static_argnums=(0,)) def _landing_in_water(self, state: UpNDownState, new_position_x: chex.Array, new_position_y: chex.Array) -> chex.Array: - road_A_x = ((new_position_y - self.consts.FIRST_TRACK_CORNERS_Y[state.player_car.road_index_A]) / (self.consts.FIRST_TRACK_CORNERS_Y[state.player_car.road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_Y[state.player_car.road_index_A])) * (self.consts.FIRST_TRACK_CORNERS_X[state.player_car.road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_X[state.player_car.road_index_A]) + self.consts.FIRST_TRACK_CORNERS_X[state.player_car.road_index_A] - road_B_x = ((new_position_y - self.consts.SECOND_TRACK_CORNERS_Y[state.player_car.road_index_B]) / (self.consts.SECOND_TRACK_CORNERS_Y[state.player_car.road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_Y[state.player_car.road_index_B])) * (self.consts.SECOND_TRACK_CORNERS_X[state.player_car.road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[state.player_car.road_index_B]) + self.consts.SECOND_TRACK_CORNERS_X[state.player_car.road_index_B] + road_A_x = ((new_position_y - self.consts.TRACK_CORNERS_Y[state.player_car.road_index_A]) / (self.consts.TRACK_CORNERS_Y[state.player_car.road_index_A+1] - self.consts.TRACK_CORNERS_Y[state.player_car.road_index_A])) * (self.consts.FIRST_TRACK_CORNERS_X[state.player_car.road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_X[state.player_car.road_index_A]) + self.consts.FIRST_TRACK_CORNERS_X[state.player_car.road_index_A] + road_B_x = ((new_position_y - self.consts.TRACK_CORNERS_Y[state.player_car.road_index_B]) / (self.consts.TRACK_CORNERS_Y[state.player_car.road_index_B+1] - self.consts.TRACK_CORNERS_Y[state.player_car.road_index_B])) * (self.consts.SECOND_TRACK_CORNERS_X[state.player_car.road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[state.player_car.road_index_B]) + self.consts.SECOND_TRACK_CORNERS_X[state.player_car.road_index_B] distance_to_road_A = jnp.abs(new_position_x - road_A_x) distance_to_road_B = jnp.abs(new_position_x - road_B_x) landing_in_Water = jnp.logical_and(distance_to_road_A > self.consts.LANDING_ZONE, distance_to_road_B > self.consts.LANDING_ZONE) @@ -187,17 +188,15 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: operand=car_direction_x, ) - is_landing = jnp.logical_and(state.jump_cooldown == 1, jump_cooldown == 0) - ##calculate new position with speed (TODO: calculate better speed) player_y = jax.lax.cond( - state.step_counter % (16/ speed_divider) == 8 / speed_divider, + jnp.logical_and((state.step_counter % (16/ speed_divider) == 8 / speed_divider), player_speed != 0,), lambda s: jax.lax.cond( is_jumping, lambda s: state.player_car.position.y + jax.lax.abs(player_speed) / player_speed * -1, lambda s: jax.lax.cond( - self._isOnLine(state, state.player_car.position.x, s + jax.lax.abs(player_speed) / player_speed * -1, 1), + self._isOnLine(state, jax.lax.abs(player_speed) / player_speed, 1), lambda s: s + jax.lax.abs(player_speed) / player_speed * -1, lambda s: jnp.array(s, float), operand=state.player_car.position.y, @@ -207,12 +206,12 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: operand=state.player_car.position.y, ) player_x = jax.lax.cond( - state.step_counter % (16/ speed_divider) == 0, + jnp.logical_and((state.step_counter % (16/ speed_divider) == 0), player_speed != 0,), lambda s: jax.lax.cond( is_jumping, lambda s: s + jax.lax.abs(player_speed) / player_speed * car_direction_x, lambda s: jax.lax.cond( - self._isOnLine(state, s + jax.lax.abs(player_speed) / player_speed * car_direction_x, player_y, 1), + self._isOnLine(state, jax.lax.abs(player_speed) / player_speed, 2), lambda s: s + jax.lax.abs(player_speed) / player_speed * car_direction_x, lambda s: jnp.array(s, float), operand=state.player_car.position.x, @@ -224,7 +223,7 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: ##if y not on mx +b then no move - jax.debug.print("Player X: {}, Player Y: {}, car_direction_x: {}", player_x, player_y, car_direction_x) + landing_in_Water, between_roads, road_A_x, road_B_x = self._landing_in_water(state, player_x, player_y) landing_in_Water = jnp.logical_and(is_landing, landing_in_Water) @@ -251,18 +250,18 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: current_road == 2, lambda s: road_index_A, lambda s: jax.lax.cond( - self.consts.FIRST_TRACK_CORNERS_Y[road_index_A] < player_y, + self.consts.TRACK_CORNERS_Y[road_index_A] < player_y, lambda s: road_index_A - 1, lambda s: jax.lax.cond( - len(self.consts.FIRST_TRACK_CORNERS_Y) == road_index_A + 1, + len(self.consts.TRACK_CORNERS_Y) == road_index_A + 1, lambda s: jax.lax.cond( - self.consts.FIRST_TRACK_CORNERS_Y[0] > player_y, + self.consts.TRACK_CORNERS_Y[0] > player_y, lambda s: 0, lambda s: road_index_A, operand=None, ), lambda s: jax.lax.cond( - self.consts.FIRST_TRACK_CORNERS_Y[road_index_A+1] > player_y, + self.consts.TRACK_CORNERS_Y[road_index_A+1] > player_y, lambda s: road_index_A + 1, lambda s: road_index_A, operand=None, @@ -278,18 +277,18 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: current_road == 2, lambda s: road_index_B, lambda s: jax.lax.cond( - self.consts.SECOND_TRACK_CORNERS_Y[road_index_B] < player_y, + self.consts.TRACK_CORNERS_Y[road_index_B] < player_y, lambda s: road_index_B - 1, lambda s: jax.lax.cond( - len(self.consts.SECOND_TRACK_CORNERS_Y) == road_index_B + 1, + len(self.consts.TRACK_CORNERS_Y) == road_index_B + 1, lambda s: jax.lax.cond( - self.consts.SECOND_TRACK_CORNERS_Y[0] > player_y, + self.consts.TRACK_CORNERS_Y[0] > player_y, lambda s: 0, lambda s: road_index_B, operand=None, ), lambda s: jax.lax.cond( - self.consts.SECOND_TRACK_CORNERS_Y[road_index_B+1] > player_y, + self.consts.TRACK_CORNERS_Y[road_index_B+1] > player_y, lambda s: road_index_B + 1, lambda s: road_index_B, operand=None, @@ -301,7 +300,7 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: operand=None, ) - + jax.debug.print("Player X: {}, Player Y: {}, car_direction_x: {}", player_x, player_y, car_direction_x) #jax.debug.print("Player X: {}, Player Y: {}, on road: {}, jumping: {}, speed: {}, road index A: {}, road index B: {}, current road: {}", player_x, player_y, is_on_road, is_jumping, player_speed, road_index_A, road_index_B, current_road) From 70226a2300937ce03693ea1f11376270ea923c51 Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sat, 13 Dec 2025 20:40:38 +0100 Subject: [PATCH 14/76] add collectibles to the game --- src/jaxatari/games/jax_upndown.py | 624 +++++++++++++++++- .../games/sprites/up_n_down/all_flags_top.npy | Bin 0 -> 8752 bytes .../sprites/up_n_down/all_lives_bottom.npy | Bin 0 -> 3248 bytes .../games/sprites/up_n_down/balloon.npy | Bin 0 -> 640 bytes .../sprites/up_n_down/balloon_backup.npy | Bin 0 -> 548 bytes .../games/sprites/up_n_down/cherry.npy | Bin 0 -> 640 bytes .../games/sprites/up_n_down/cherry_backup.npy | Bin 0 -> 640 bytes .../games/sprites/up_n_down/flag_pole.npy | Bin 0 -> 192 bytes .../sprites/up_n_down/ice_cream_cone.npy | Bin 0 -> 640 bytes .../up_n_down/ice_cream_cone_backup.npy | Bin 0 -> 576 bytes .../games/sprites/up_n_down/lollypop.npy | Bin 0 -> 640 bytes .../sprites/up_n_down/lollypop_backup.npy | Bin 0 -> 512 bytes .../games/sprites/up_n_down/pink_flag.npy | Bin 0 -> 392 bytes .../games/sprites/up_n_down/purple_cherry.npy | Bin 0 -> 640 bytes 14 files changed, 612 insertions(+), 12 deletions(-) create mode 100644 src/jaxatari/games/sprites/up_n_down/all_flags_top.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/all_lives_bottom.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/balloon.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/balloon_backup.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/cherry.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/cherry_backup.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/flag_pole.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/ice_cream_cone.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/ice_cream_cone_backup.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/lollypop.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/lollypop_backup.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/pink_flag.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/purple_cherry.npy diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index a75917186..e469364bf 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -26,6 +26,46 @@ class UpNDownConstants(NamedTuple): SECOND_TRACK_CORNERS_X: chex.Array = jnp.array([115, 75, 20, 75, 133, 75, 22, 37, 63, 27, 66, 30, 63, 24, 60, 38, 38, 75, 131, 111, 150, 118, 118, 98, 150, 115]) PLAYER_SIZE: Tuple[int, int] = (4, 16) INITIAL_ROAD_POS_Y: int = 25 + # Flag constants - 8 flags with different colors matching the top row + NUM_FLAGS: int = 8 + FLAG_SIZE: Tuple[int, int] = (11, 6) # height, width of the flag sprite + FLAG_POLE_SIZE: Tuple[int, int] = (7, 2) # height, width of the pole sprite + # Flag colors as RGBA values (matching the top row from left to right) + FLAG_COLORS: chex.Array = jnp.array([ + [184, 50, 50, 255], # Red + [181, 83, 40, 255], # Orange + [162, 98, 33, 255], # Dark orange + [134, 134, 29, 255], # Yellow/olive + [200, 72, 72, 255], # Pink (original) + [168, 48, 143, 255], # Magenta + [125, 48, 173, 255], # Purple + [78, 50, 181, 255], # Blue + ]) + # Top display positions for each flag (x coordinates where blackout squares appear) + FLAG_TOP_X_POSITIONS: chex.Array = jnp.array([13, 30, 47, 64, 82, 98, 118, 132]) + FLAG_TOP_Y: int = 20 + FLAG_BLACKOUT_SIZE: Tuple[int, int] = (14, 14) # Size of blackout square + FLAG_COLLECTION_SCORE: int = 75 # Points awarded for collecting a flag + PICKUP_SCORE: int = 100 # Points awarded for jumping on a pickup truck + FLAG_CARRIER_SCORE: int = 125 # Points awarded for jumping on a flag carrier + CAMARO_SCORE: int = 150 # Points awarded for jumping on a camaro + TRUCK_SCORE: int = 175 # Points awarded for jumping on a truck + # Collectible constants - unified dynamic spawning + MAX_COLLECTIBLES: int = 2 # Maximum collectibles that can exist at once (pool of mixed types) + COLLECTIBLE_SIZE: Tuple[int, int] = (8, 8) # height, width of collectible sprite + COLLECTIBLE_SPAWN_INTERVAL: int = 200 # Steps between spawn attempts + COLLECTIBLE_DESPAWN_DISTANCE: int = 500 # Distance beyond which collectibles despawn + # Collectible types (indices for type field) + COLLECTIBLE_TYPE_CHERRY: int = 0 + COLLECTIBLE_TYPE_BALLOON: int = 1 + COLLECTIBLE_TYPE_LOLLYPOP: int = 2 + COLLECTIBLE_TYPE_ICE_CREAM: int = 3 + # Collectible type spawn probabilities (must sum to 100) + COLLECTIBLE_SPAWN_PROBABILITIES: chex.Array = jnp.array([40, 20, 20, 20], dtype=jnp.int32) # Cherry: 40%, Balloon: 20%, Lollypop: 20%, IceCream: 20% + # Collectible type scores + COLLECTIBLE_SCORES: chex.Array = jnp.array([50, 65, 70, 75], dtype=jnp.int32) # [cherry, balloon, lollypop, ice_cream] + # Shared collectible colors + COLLECTIBLE_COLORS: chex.Array = FLAG_COLORS @@ -45,6 +85,27 @@ class Car(NamedTuple): road_index_B: chex.Array direction_x: chex.Array +class Flag(NamedTuple): + """Represents a collectible flag on the road.""" + y: chex.Array # Y position in world coordinates (like player_car.position.y) + road: chex.Array # Which road the flag is on (0 or 1) + road_segment: chex.Array # Which road segment index the flag is on + color_idx: chex.Array # Index into FLAG_COLORS array + collected: chex.Array # Whether this flag has been collected + +class Collectible(NamedTuple): + """Represents a dynamically spawning collectible item on the road. + + Can be any type: cherry (0), balloon (1), lollypop (2), or ice cream (3). + The type determines the sprite and point value. + """ + y: chex.Array # Y position in world coordinates + x: chex.Array # X position on the road + road: chex.Array # Which road the collectible is on (0 or 1) + color_idx: chex.Array # Index into COLLECTIBLE_COLORS array + type_id: chex.Array # Type of collectible (0=cherry, 1=balloon, 2=lollypop, 3=ice_cream) + active: chex.Array # Whether this collectible slot is active (spawned) + class UpNDownState(NamedTuple): score: chex.Array difficulty: chex.Array @@ -53,6 +114,12 @@ class UpNDownState(NamedTuple): is_on_road: chex.Array player_car: Car step_counter: chex.Array + # Flag state - tracks all 8 flags + flags: Flag # Contains arrays of size NUM_FLAGS for each field + flags_collected_mask: chex.Array # Boolean mask of which flag colors have been collected (size NUM_FLAGS) + # Collectible state - dynamic spawning (mixed types: cherry, balloon, lollypop, ice cream) + collectibles: Collectible # Contains arrays of size MAX_COLLECTIBLES for each field + collectible_spawn_timer: chex.Array # Counter for collectible spawn timing @@ -60,12 +127,6 @@ class UpNDownState(NamedTuple): class UpNDownObservation(NamedTuple): player: EntityPosition -class Collectible(NamedTuple): - position: EntityPosition - type: chex.Array - value: chex.Array - - class UpNDownInfo(NamedTuple): time: jnp.ndarray @@ -104,6 +165,23 @@ def _getSlopeAndB(self, state: UpNDownState) -> chex.Array: b = tracky[roadIndex] - slope * trackx[roadIndex] return slope, b + @partial(jax.jit, static_argnums=(0,)) + def _get_x_on_road(self, y: chex.Array, road_segment: chex.Array, track_corners_x: chex.Array) -> chex.Array: + """Calculate the X position on a road given a Y coordinate and road segment.""" + y1 = self.consts.TRACK_CORNERS_Y[road_segment] + y2 = self.consts.TRACK_CORNERS_Y[road_segment + 1] + x1 = track_corners_x[road_segment] + x2 = track_corners_x[road_segment + 1] + + # Linear interpolation: x = x1 + (y - y1) * (x2 - x1) / (y2 - y1) + t = jax.lax.cond( + y2 != y1, + lambda _: (y - y1) / (y2 - y1), + lambda _: 0.0, + operand=None, + ) + return x1 + t * (x2 - x1) + @partial(jax.jit, static_argnums=(0,)) def _isOnLine(self, state: UpNDownState, player_speed: chex.Array, turn: chex.Array) -> chex.Array: slope, b = self._getSlopeAndB(state) @@ -125,6 +203,223 @@ def _landing_in_water(self, state: UpNDownState, new_position_x: chex.Array, new between_roads = jnp.logical_and(new_position_x > jnp.minimum(road_A_x, road_B_x), new_position_x < jnp.maximum(road_A_x, road_B_x)) return landing_in_Water, between_roads, road_A_x, road_B_x + def _flag_step(self, state: UpNDownState, new_player_y: chex.Array, player_x: chex.Array, current_road: chex.Array) -> Tuple[Flag, chex.Array, chex.Array]: + """Update flag collection state and score. + + Args: + state: Current game state + new_player_y: Updated player Y position after movement + player_x: Current player X position + current_road: Current road player is on + + Returns: + Tuple of (updated_flags, score_delta, flags_collected_mask) + """ + # Check collision for each flag + def check_flag_collision(flag_idx): + flag_y = state.flags.y[flag_idx] + flag_road = state.flags.road[flag_idx] + flag_collected = state.flags.collected[flag_idx] + + # Calculate flag X position on its road + flag_segment = state.flags.road_segment[flag_idx] + flag_x = jax.lax.cond( + flag_road == 0, + lambda _: self._get_x_on_road(flag_y, flag_segment, self.consts.FIRST_TRACK_CORNERS_X), + lambda _: self._get_x_on_road(flag_y, flag_segment, self.consts.SECOND_TRACK_CORNERS_X), + operand=None, + ) + + # Check if player is close enough to collect the flag + y_distance = jnp.abs(new_player_y - flag_y) + x_distance = jnp.abs(player_x - flag_x) + same_road = jnp.logical_or( + jnp.logical_and(current_road == 0, flag_road == 0), + jnp.logical_and(current_road == 1, flag_road == 1), + ) + + collision = jnp.logical_and( + jnp.logical_and(y_distance < 5, x_distance < 5), #change the distance threshold if needed + jnp.logical_and(same_road, ~flag_collected) + ) + return collision + + new_collections = jax.vmap(check_flag_collision)(jnp.arange(self.consts.NUM_FLAGS)) + + # Update flags collected state + new_flags_collected = jnp.logical_or(state.flags.collected, new_collections) + new_flags_collected_mask = jnp.logical_or(state.flags_collected_mask, new_collections) + + # Update score based on collected flags + flag_score = jnp.sum(new_collections.astype(jnp.int32) * self.consts.FLAG_COLLECTION_SCORE) + + new_flags = Flag( + y=state.flags.y, + road=state.flags.road, + road_segment=state.flags.road_segment, + color_idx=state.flags.color_idx, + collected=new_flags_collected, + ) + + return new_flags, flag_score, new_flags_collected_mask + + def _collectible_step(self, state: UpNDownState, new_player_y: chex.Array, player_x: chex.Array, current_road: chex.Array) -> Tuple[Collectible, chex.Array, chex.Array]: + """Update collectible spawning, despawning, and collection (unified for all types). + + Handles mixed-type collectibles (cherry, balloon, lollypop, ice cream) in a single pool. + Type is randomized on spawn with probabilities defined in COLLECTIBLE_SPAWN_PROBABILITIES. + + Args: + state: Current game state + new_player_y: Updated player Y position after movement + player_x: Current player X position + current_road: Current road player is on + + Returns: + Tuple of (updated_collectibles, score_delta, new_spawn_timer) + """ + # Collectible spawning logic - decrement timer and spawn when ready + new_collectible_timer = jax.lax.cond( + state.collectible_spawn_timer <= 0, + lambda _: self.consts.COLLECTIBLE_SPAWN_INTERVAL, + lambda _: state.collectible_spawn_timer - 1, + operand=None, + ) + + # Attempt to spawn when timer hits 0 + should_spawn = state.collectible_spawn_timer <= 0 + + # Find first inactive collectible slot + def find_inactive_idx(collectibles_in): + inactive_mask = ~collectibles_in.active + first_inactive = jnp.argmax(inactive_mask.astype(jnp.int32)) + has_inactive = jnp.any(inactive_mask) + return jax.lax.cond( + has_inactive, + lambda _: first_inactive, + lambda _: jnp.array(0, dtype=jnp.int32), + operand=None, + ), has_inactive + + spawn_idx, has_inactive_slot = find_inactive_idx(state.collectibles) + + # Generate random spawn position using fold_in for deterministic randomness + base_key = jax.random.PRNGKey(0) + key_for_spawn = jax.random.fold_in(base_key, state.step_counter) + key1, key2, key3, key4, key5 = jax.random.split(key_for_spawn, 5) + y_spawn = jax.random.uniform(key1, minval=-900.0, maxval=-100.0) + road_spawn = jnp.array(jax.random.randint(key2, shape=(), minval=0, maxval=2), dtype=jnp.int32) + color_spawn = jnp.array(jax.random.randint(key3, shape=(), minval=0, maxval=len(self.consts.COLLECTIBLE_COLORS)), dtype=jnp.int32) + + # Randomly select collectible type based on spawn probabilities + # Convert probabilities (%) to cumulative distribution for sampling + rand_type = jax.random.uniform(key4, minval=0.0, maxval=100.0) + + # Use cumulative probabilities: cherry [0-40], balloon [40-60], lollypop [60-80], ice_cream [80-100] + def select_type(rand_val): + # Returns 0=cherry, 1=balloon, 2=lollypop, 3=ice_cream + type_id = jnp.where( + rand_val < self.consts.COLLECTIBLE_SPAWN_PROBABILITIES[0], + jnp.int32(self.consts.COLLECTIBLE_TYPE_CHERRY), + jnp.where( + rand_val < self.consts.COLLECTIBLE_SPAWN_PROBABILITIES[1], + jnp.int32(self.consts.COLLECTIBLE_TYPE_BALLOON), + jnp.where( + rand_val < self.consts.COLLECTIBLE_SPAWN_PROBABILITIES[2], + jnp.int32(self.consts.COLLECTIBLE_TYPE_LOLLYPOP), + jnp.int32(self.consts.COLLECTIBLE_TYPE_ICE_CREAM) + ) + ) + ) + return type_id + + type_id_spawn = select_type(rand_type) + + # Calculate X position on road + def get_road_segment(y): + segments = jnp.sum(self.consts.TRACK_CORNERS_Y > y) + return jnp.clip(segments - 1, 0, len(self.consts.TRACK_CORNERS_Y) - 2) + + segment_spawn = get_road_segment(y_spawn) + x_spawn = jax.lax.cond( + road_spawn == 0, + lambda _: self._get_x_on_road(y_spawn, segment_spawn, self.consts.FIRST_TRACK_CORNERS_X), + lambda _: self._get_x_on_road(y_spawn, segment_spawn, self.consts.SECOND_TRACK_CORNERS_X), + operand=None, + ) + + # Create mask for which collectibles to update + update_mask = (jnp.arange(self.consts.MAX_COLLECTIBLES) == spawn_idx) & should_spawn & has_inactive_slot + + # Update collectibles with proper masking + updated_collectibles = Collectible( + y=jnp.where(update_mask, y_spawn, state.collectibles.y), + x=jnp.where(update_mask, x_spawn, state.collectibles.x), + road=jnp.where(update_mask, road_spawn, state.collectibles.road), + color_idx=jnp.where(update_mask, color_spawn, state.collectibles.color_idx), + type_id=jnp.where(update_mask, type_id_spawn, state.collectibles.type_id), + active=jnp.where(update_mask, True, state.collectibles.active), + ) + + # Despawn logic - remove collectibles too far from player + def check_despawn(idx): + c_y = updated_collectibles.y[idx] + c_active = updated_collectibles.active[idx] + distance = jnp.abs(new_player_y - c_y) + too_far = distance > self.consts.COLLECTIBLE_DESPAWN_DISTANCE + should_despawn = jnp.logical_and(c_active, too_far) + return should_despawn + + despawn_mask = jax.vmap(check_despawn)(jnp.arange(self.consts.MAX_COLLECTIBLES)) + new_active = jnp.logical_and(updated_collectibles.active, ~despawn_mask) + + # Collision detection + def check_collision(idx): + c_y = updated_collectibles.y[idx] + c_x = updated_collectibles.x[idx] + c_road = updated_collectibles.road[idx] + c_active = updated_collectibles.active[idx] + + y_distance = jnp.abs(new_player_y - c_y) + x_distance = jnp.abs(player_x - c_x) + same_road = jnp.logical_or( + jnp.logical_and(current_road == 0, c_road == 0), + jnp.logical_and(current_road == 1, c_road == 1), + ) + + collision = jnp.logical_and( + jnp.logical_and(y_distance < 5, x_distance < 5), + jnp.logical_and(same_road, c_active) + ) + return collision + + collections = jax.vmap(check_collision)(jnp.arange(self.consts.MAX_COLLECTIBLES)) + + # Deactivate collected items + new_active = jnp.logical_and(new_active, ~collections) + + # Update score - use type_id to look up score value + def get_collection_score(idx): + is_collected = collections[idx] + type_id = updated_collectibles.type_id[idx] + # Look up score based on type_id using array indexing + score = self.consts.COLLECTIBLE_SCORES[type_id] + return jnp.where(is_collected, score, 0) + + score_array = jax.vmap(get_collection_score)(jnp.arange(self.consts.MAX_COLLECTIBLES)) + score_delta = jnp.sum(score_array) + + updated_collectibles = Collectible( + y=updated_collectibles.y, + x=updated_collectibles.x, + road=updated_collectibles.road, + color_idx=updated_collectibles.color_idx, + type_id=updated_collectibles.type_id, + active=new_active, + ) + + return updated_collectibles, score_delta, new_collectible_timer + def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: up = jnp.logical_or(action == Action.UP, action == Action.UPFIRE) down = jnp.logical_or(action == Action.DOWN, action == Action.DOWNFIRE) @@ -302,8 +597,9 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: jax.debug.print("Player X: {}, Player Y: {}, car_direction_x: {}", player_x, player_y, car_direction_x) - - #jax.debug.print("Player X: {}, Player Y: {}, on road: {}, jumping: {}, speed: {}, road index A: {}, road index B: {}, current road: {}", player_x, player_y, is_on_road, is_jumping, player_speed, road_index_A, road_index_B, current_road) + # Calculate new player y position after wrapping + new_player_y = -((player_y * -1) % 1036) + return UpNDownState( score=state.score, difficulty=state.difficulty, @@ -313,7 +609,7 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: player_car=Car( position=EntityPosition( x=player_x, - y=-((player_y * -1) % 1036), + y=new_player_y, width=state.player_car.position.width, height=state.player_car.position.height, ), @@ -325,10 +621,105 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: type=state.player_car.type, ), step_counter=state.step_counter + 1, + flags=state.flags, + flags_collected_mask=state.flags_collected_mask, + collectibles=state.collectibles, + collectible_spawn_timer=state.collectible_spawn_timer, + ) + + def _flag_step_main(self, state: UpNDownState) -> UpNDownState: + """Update flag collection state and score.""" + new_player_y = state.player_car.position.y + player_x = state.player_car.position.x + current_road = state.player_car.current_road + + new_flags, flag_score, new_flags_collected_mask = self._flag_step( + state, new_player_y, player_x, current_road + ) + + return UpNDownState( + score=state.score + flag_score, + difficulty=state.difficulty, + jump_cooldown=state.jump_cooldown, + is_jumping=state.is_jumping, + is_on_road=state.is_on_road, + player_car=state.player_car, + step_counter=state.step_counter, + flags=new_flags, + flags_collected_mask=new_flags_collected_mask, + collectibles=state.collectibles, + collectible_spawn_timer=state.collectible_spawn_timer, + ) + + def _collectible_step_main(self, state: UpNDownState) -> UpNDownState: + """Update collectible spawning, despawning, and collection.""" + new_player_y = state.player_car.position.y + player_x = state.player_car.position.x + current_road = state.player_car.current_road + + updated_collectibles, collectible_score, new_collectible_timer = self._collectible_step( + state, new_player_y, player_x, current_road + ) + + return UpNDownState( + score=state.score + collectible_score, + difficulty=state.difficulty, + jump_cooldown=state.jump_cooldown, + is_jumping=state.is_jumping, + is_on_road=state.is_on_road, + player_car=state.player_car, + step_counter=state.step_counter, + flags=state.flags, + flags_collected_mask=state.flags_collected_mask, + collectibles=updated_collectibles, + collectible_spawn_timer=new_collectible_timer, ) def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: + # Initialize flags at random positions along the track + # Use key for randomness if provided, otherwise use default positions + if key is None: + key = jax.random.PRNGKey(42) + + # Evenly spread flags along the track with small jitter + key, subkey = jax.random.split(key) + base_y = jnp.linspace(-900.0, -100.0, self.consts.NUM_FLAGS) + jitter = jax.random.uniform(subkey, shape=(self.consts.NUM_FLAGS,), minval=-40.0, maxval=40.0) + flag_y_offsets = base_y + jitter + + # Alternate roads 0/1 for variety + flag_roads = jnp.arange(self.consts.NUM_FLAGS) % 2 + + # Calculate which road segment each flag is on based on Y position + def get_road_segment(y): + # Find the segment where TRACK_CORNERS_Y[i] > y >= TRACK_CORNERS_Y[i+1] + segments = jnp.sum(self.consts.TRACK_CORNERS_Y > y) + return jnp.clip(segments - 1, 0, len(self.consts.TRACK_CORNERS_Y) - 2) + + flag_segments = jax.vmap(get_road_segment)(flag_y_offsets) + + # Each flag color index corresponds to its position (0-7) + flag_color_indices = jnp.arange(self.consts.NUM_FLAGS) + + flags = Flag( + y=flag_y_offsets, + road=flag_roads, + road_segment=flag_segments, + color_idx=flag_color_indices, + collected=jnp.zeros(self.consts.NUM_FLAGS, dtype=jnp.bool_), + ) + + # Initialize collectibles as all inactive (will spawn dynamically with mixed types) + collectibles = Collectible( + y=jnp.zeros(self.consts.MAX_COLLECTIBLES), + x=jnp.zeros(self.consts.MAX_COLLECTIBLES), + road=jnp.zeros(self.consts.MAX_COLLECTIBLES, dtype=jnp.int32), + color_idx=jnp.zeros(self.consts.MAX_COLLECTIBLES, dtype=jnp.int32), + type_id=jnp.zeros(self.consts.MAX_COLLECTIBLES, dtype=jnp.int32), + active=jnp.zeros(self.consts.MAX_COLLECTIBLES, dtype=jnp.bool_), + ) + state = UpNDownState( score=0, difficulty=self.consts.DIFFICULTIES[0], @@ -350,6 +741,10 @@ def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: type=0, ), step_counter=jnp.array(0), + flags=flags, + flags_collected_mask=jnp.zeros(self.consts.NUM_FLAGS, dtype=jnp.bool_), + collectibles=collectibles, + collectible_spawn_timer=jnp.array(0, dtype=jnp.int32), ) initial_obs = self._get_observation(state) return initial_obs, state @@ -358,6 +753,8 @@ def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: def step(self, state: UpNDownState, action: chex.Array) -> Tuple[UpNDownObservation, UpNDownState, float, bool, UpNDownInfo]: previous_state = state state = self._player_step(state, action) + state = self._flag_step_main(state) + state = self._collectible_step_main(state) done = self._get_done(state) env_reward = self._get_reward(previous_state, state) @@ -439,9 +836,10 @@ def __init__(self, consts: UpNDownConstants = None): top_block = self._createBackgroundSprite((25, self.config.game_dimensions[1])) bottom_block = self._createBackgroundSprite((16, self.config.game_dimensions[1])) temp_pointer = self._createBackgroundSprite((1, 1)) + blackout_square = self._createBackgroundSprite(self.consts.FLAG_BLACKOUT_SIZE) # 2. Update asset config to include both walls - asset_config, road_files = self._get_asset_config(background, top_block, bottom_block, temp_pointer) + asset_config, road_files = self._get_asset_config(background, top_block, bottom_block, temp_pointer, blackout_square) sprite_path = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/" # 3. Make a single call to the setup function @@ -459,6 +857,28 @@ def __init__(self, consts: UpNDownConstants = None): repeats = max(1, math.ceil(self.view_height / road_cycle) + 2) self._road_tile_offsets = jnp.arange(-repeats, repeats + 1, dtype=jnp.int32) * jnp.int32(self.complete_road_size) self._num_road_tiles = int(self._road_tile_offsets.shape[0]) + + # Precompute flag mask data for recoloring without special-casing pink + self.flag_base_mask = self.SHAPE_MASKS["pink_flag"] + self.flag_solid_mask = self.flag_base_mask != self.jr.TRANSPARENT_ID + self.flag_palette_ids = self._compute_flag_palette_ids() + + # Precompute collectible mask data for recoloring (unified for all types: cherry, balloon, lollypop, ice cream) + self.cherry_base_mask = self.SHAPE_MASKS["cherry"] + self.cherry_solid_mask = self.cherry_base_mask != self.jr.TRANSPARENT_ID + self.cherry_palette_ids = self._compute_flag_palette_ids() + + self.balloon_base_mask = self.SHAPE_MASKS["balloon"] + self.balloon_solid_mask = self.balloon_base_mask != self.jr.TRANSPARENT_ID + self.balloon_palette_ids = self._compute_flag_palette_ids() + + self.lollypop_base_mask = self.SHAPE_MASKS["lollypop"] + self.lollypop_solid_mask = self.lollypop_base_mask != self.jr.TRANSPARENT_ID + self.lollypop_palette_ids = self._compute_flag_palette_ids() + + self.ice_cream_base_mask = self.SHAPE_MASKS["ice_cream"] + self.ice_cream_solid_mask = self.ice_cream_base_mask != self.jr.TRANSPARENT_ID + self.ice_cream_palette_ids = self._compute_flag_palette_ids() def _createBackgroundSprite(self, dimensions: Tuple[int, int]) -> jnp.ndarray: """Creates a procedural background sprite for the game.""" @@ -479,8 +899,21 @@ def _get_road_sprite_sizes(self, road_files: list[str]) -> list: complete_size = int(sum(sizes)) jax.debug.print("Complete road size: {}", complete_size) return sizes, complete_size + + def _find_palette_id(self, rgba: jnp.ndarray) -> int: + """Return palette index for an RGBA color, falling back to first entry if missing.""" + color_rgb = rgba[:3] + palette_rgb = self.PALETTE[:, :3] + matches = jnp.all(palette_rgb == color_rgb, axis=1) + found = jnp.argmax(matches) + # If no match, fallback to 0 (background) to avoid crashes + return int(found) + + def _compute_flag_palette_ids(self) -> jnp.ndarray: + """Precompute palette indices for each flag color without special-casing pink.""" + return jnp.array([self._find_palette_id(color) for color in self.consts.FLAG_COLORS], dtype=jnp.int32) - def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.ndarray, bottomBlockSprite: jnp.ndarray, tempPointer: jnp.ndarray) -> tuple[list, list[str]]: + def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.ndarray, bottomBlockSprite: jnp.ndarray, tempPointer: jnp.ndarray, blackoutSquare: jnp.ndarray) -> tuple[list, list[str]]: """Returns the asset manifest and ordered road files.""" road_dir = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/roads" road_files = sorted( @@ -494,7 +927,16 @@ def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.n {'name': 'player', 'type': 'single', 'file': 'player_car.npy'}, {'name': 'wall_top', 'type': 'procedural', 'data': topBlockSprite}, {'name': 'wall_bottom', 'type': 'procedural', 'data': bottomBlockSprite}, + {'name': 'all_flags_top', 'type': 'single', 'file': 'all_flags_top.npy'}, + {'name': 'all_lives_bottom', 'type': 'single', 'file': 'all_lives_bottom.npy'}, + {'name': 'pink_flag', 'type': 'single', 'file': 'pink_flag.npy'}, + {'name': 'flag_pole', 'type': 'single', 'file': 'flag_pole.npy'}, + {'name': 'cherry', 'type': 'single', 'file': 'cherry.npy'}, + {'name': 'balloon', 'type': 'single', 'file': 'balloon.npy'}, + {'name': 'lollypop', 'type': 'single', 'file': 'lollypop.npy'}, + {'name': 'ice_cream', 'type': 'single', 'file': 'ice_cream_cone.npy'}, {'name': 'tempPointer', 'type': 'procedural', 'data': tempPointer}, + {'name': 'blackout_square', 'type': 'procedural', 'data': blackoutSquare}, ], roads @partial(jax.jit, static_argnums=(0,)) @@ -556,7 +998,165 @@ def combine(i, acc): wall_bottom_mask = self.SHAPE_MASKS["wall_bottom"] raster = self.jr.render_at(raster, 0, 210 - wall_bottom_mask.shape[0], wall_bottom_mask) + all_flags_top_mask = self.SHAPE_MASKS["all_flags_top"] + raster = self.jr.render_at(raster, 10, 20, all_flags_top_mask) + + # Render flags on the road + flag_pole_mask = self.SHAPE_MASKS["flag_pole"] + + def render_flag(carry, flag_idx): + raster = carry + flag_y = state.flags.y[flag_idx] + flag_road = state.flags.road[flag_idx] + flag_segment = state.flags.road_segment[flag_idx] + flag_collected = state.flags.collected[flag_idx] + flag_color_idx = state.flags.color_idx[flag_idx] + + # Calculate flag X position on its road + flag_x = jax.lax.cond( + flag_road == 0, + lambda _: self._get_flag_x_on_road(flag_y, flag_segment, self.consts.FIRST_TRACK_CORNERS_X), + lambda _: self._get_flag_x_on_road(flag_y, flag_segment, self.consts.SECOND_TRACK_CORNERS_X), + operand=None, + ) + + # Calculate screen Y position relative to player + # The player is always rendered at Y=105, so flags scroll based on player position + screen_y = 105 + (flag_y - state.player_car.position.y) + + # Check if flag is visible on screen and not collected + is_visible = jnp.logical_and( + jnp.logical_and(screen_y > 25, screen_y < 195), + ~flag_collected + ) + + # Colorize the base flag mask + color_id = self.flag_palette_ids[flag_color_idx] + colored_flag_mask = jnp.where( + self.flag_solid_mask, + color_id, + self.flag_base_mask, + ) + + # Render flag if visible + raster = jax.lax.cond( + is_visible, + lambda r: self.jr.render_at( + self.jr.render_at(r, flag_x.astype(jnp.int32), screen_y.astype(jnp.int32), colored_flag_mask), + (flag_x + 5).astype(jnp.int32), screen_y.astype(jnp.int32), flag_pole_mask + ), + lambda r: r, + operand=raster, + ) + return raster, None + + raster, _ = jax.lax.scan(render_flag, raster, jnp.arange(self.consts.NUM_FLAGS)) + + # Black out collected flags at the top + blackout_mask = self.SHAPE_MASKS["blackout_square"] + + def render_blackout(carry, flag_idx): + raster = carry + flag_collected = state.flags_collected_mask[flag_idx] + blackout_x = self.consts.FLAG_TOP_X_POSITIONS[flag_idx] + blackout_y = self.consts.FLAG_TOP_Y + + raster = jax.lax.cond( + flag_collected, + lambda r: self.jr.render_at(r, blackout_x, blackout_y, blackout_mask), + lambda r: r, + operand=raster, + ) + return raster, None + + raster, _ = jax.lax.scan(render_blackout, raster, jnp.arange(self.consts.NUM_FLAGS)) + + # Render collectibles (unified for all types: cherry, balloon, lollypop, ice cream) + def render_collectible(carry, collectible_idx): + raster = carry + collectible_y = state.collectibles.y[collectible_idx] + collectible_x = state.collectibles.x[collectible_idx] + collectible_active = state.collectibles.active[collectible_idx] + collectible_color_idx = state.collectibles.color_idx[collectible_idx] + collectible_type_id = state.collectibles.type_id[collectible_idx] + + # Calculate screen Y position relative to player + screen_y = 105 + (collectible_y - state.player_car.position.y) + + # Check if collectible is visible on screen and active + is_visible = jnp.logical_and( + jnp.logical_and(screen_y > 25, screen_y < 195), + collectible_active + ) + + # Select sprite based on type_id + # type_id: 0=cherry, 1=balloon, 2=lollypop, 3=ice_cream + def get_sprite_and_mask(type_id): + cherry_result = (self.cherry_base_mask, self.cherry_solid_mask, self.cherry_palette_ids) + balloon_result = (self.balloon_base_mask, self.balloon_solid_mask, self.balloon_palette_ids) + lollypop_result = (self.lollypop_base_mask, self.lollypop_solid_mask, self.lollypop_palette_ids) + ice_cream_result = (self.ice_cream_base_mask, self.ice_cream_solid_mask, self.ice_cream_palette_ids) + + # Use conditional branching to select sprite + result = jax.lax.cond( + type_id == self.consts.COLLECTIBLE_TYPE_CHERRY, + lambda _: cherry_result, + lambda _: jax.lax.cond( + type_id == self.consts.COLLECTIBLE_TYPE_BALLOON, + lambda _: balloon_result, + lambda _: jax.lax.cond( + type_id == self.consts.COLLECTIBLE_TYPE_LOLLYPOP, + lambda _: lollypop_result, + lambda _: ice_cream_result, + operand=None, + ), + operand=None, + ), + operand=None, + ) + return result + + base_mask, solid_mask, palette_ids = get_sprite_and_mask(collectible_type_id) + + # Only colorize inner pixels, keep black edges (palette ID 0 is black) + color_id = palette_ids[collectible_color_idx] + colored_mask = jnp.where( + (base_mask != self.jr.TRANSPARENT_ID) & (base_mask != 0), + color_id, + base_mask, + ) + + # Render collectible if visible + raster = jax.lax.cond( + is_visible, + lambda r: self.jr.render_at(r, collectible_x.astype(jnp.int32), screen_y.astype(jnp.int32), colored_mask), + lambda r: r, + operand=raster, + ) + return raster, None + + raster, _ = jax.lax.scan(render_collectible, raster, jnp.arange(self.consts.MAX_COLLECTIBLES)) + + all_lives_bottom_mask = self.SHAPE_MASKS["all_lives_bottom"] + raster = self.jr.render_at(raster, 10, 195, all_lives_bottom_mask) + wall_bottom_mask = self.SHAPE_MASKS["tempPointer"] raster = self.jr.render_at(raster, 140, 25, wall_bottom_mask) - return self.jr.render_from_palette(raster, self.PALETTE) \ No newline at end of file + return self.jr.render_from_palette(raster, self.PALETTE) + + def _get_flag_x_on_road(self, y: chex.Array, road_segment: chex.Array, track_corners_x: chex.Array) -> chex.Array: + """Calculate the X position on a road given a Y coordinate and road segment.""" + y1 = self.consts.TRACK_CORNERS_Y[road_segment] + y2 = self.consts.TRACK_CORNERS_Y[road_segment + 1] + x1 = track_corners_x[road_segment] + x2 = track_corners_x[road_segment + 1] + + # Linear interpolation: x = x1 + (y - y1) * (x2 - x1) / (y2 - y1) + t = jax.lax.cond( + y2 != y1, + lambda _: (y - y1) / (y2 - y1), + lambda _: 0.0, + operand=None, + ) + return x1 + t * (x2 - x1) \ No newline at end of file diff --git a/src/jaxatari/games/sprites/up_n_down/all_flags_top.npy b/src/jaxatari/games/sprites/up_n_down/all_flags_top.npy new file mode 100644 index 0000000000000000000000000000000000000000..37ea4a9e5e4f527e6880b0df1e441d1fa05bdcc3 GIT binary patch literal 8752 zcmeI%u}T9$5C-70^&Nz>x(E@HBQ`#Ook0)_D~aYLg&4V08%>+m!Z)z8u?*zxG+PV| zKe^kRne{AOA%@+Ve{(D!VIc40i_^2qo#Z)r$)-hhQ)b6$_WCr;`e}AMFCWXv{nfmj z7HU8u+gUmdicUci_kM zzA2CQzB=!j&RgfXDUZ)H&R3@M7K1NS9-rrP@Mb#iFkhSU_&jab&Y8uUa~9H?b7ryT zoQ1UJoLQ_nXCbXQXBKPD*)lr8&bHO1d)%2c>-xDfO#p9V^B6vtbsod>r1Kc&t@9Y3C!NPIZ=J{RJn1}! n`L5Q{6f`YreLG0z%vA3o=SYs4Nwc&t_BlEJ0t*~{Dajn8@D8u@zQvGxg7?LvQ5#6F;&hYs}VOD6KbHL_oH4d{6XkuSZ>G0Pfr UA3o=S3xZ?EAM5?kRM+?a09{(!9smFU literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/balloon.npy b/src/jaxatari/games/sprites/up_n_down/balloon.npy new file mode 100644 index 0000000000000000000000000000000000000000..a43ca6f4652f0266bd9566746452b73e0f0379d8 GIT binary patch literal 640 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1JQ);NLqoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$-W;zNMItnJ5ItsN4WC1P)28RDYj0`}0bTN?p|6@5;U<{Q;o9wdMT;1Y)^MfVp>9GS*v20s1xb literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/balloon_backup.npy b/src/jaxatari/games/sprites/up_n_down/balloon_backup.npy new file mode 100644 index 0000000000000000000000000000000000000000..f987d2d9020b1eccded6f93125da2f01d5055391 GIT binary patch literal 548 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1JQ);NLqoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$-raB7dItnJ5ItsN4WC1P)28RDYj0`}046$Q5R{ud5DhHEC5kry(nFp4` wr;ivpd}gA_q1Xv^CrAt#z}$^224lng2IHgCxb))Ehf5q}E-tfRcB04w08345EC2ui literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/cherry.npy b/src/jaxatari/games/sprites/up_n_down/cherry.npy new file mode 100644 index 0000000000000000000000000000000000000000..db4fd85033b68715abde35b17cc4451fd1a876aa GIT binary patch literal 640 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1JQ);NLqoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$-W;zNMItnJ5ItsN4WC1P)RPY~_i^0Pu50Zmn42@7ZkT^Coko6(6Ve0>% zIB^159N7%898?~z9>jnFkQ_`uj0VcV%s`QY$wO(75g76yIhZ=AIbgTJ#9=hD9WZmy j% zIB^159N7%898?~z9>jnFkQ_`uj0VcV%s`QY$wO(75g76yIhZ=AIbgTJ#9=hD9WZmy jhUtNt0k#ua9ZW4u9!R68f$E336YM6C8Y~!U1YS9i8hm=O$-~Tr(dhbN Y;xHOn9wvrP!_=UQ!}uV1d;qIF04V2!5C8xG literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/ice_cream_cone_backup.npy b/src/jaxatari/games/sprites/up_n_down/ice_cream_cone_backup.npy new file mode 100644 index 0000000000000000000000000000000000000000..f2d3a55667656dc7ffe3f358d4dcb24931971a9c GIT binary patch literal 576 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1JQ);NLqoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$-W;zPyItnJ5ItsN4WC1P)1o#hR!Z1PDB2 zgUmq|hnkC`7G@XJJh&Zb`e1IxMT3mNCI>T~6gik4Y-Zw;Ll%dzVRpj!$TUa}7r>AM E06y4*4*&oF literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/lollypop.npy b/src/jaxatari/games/sprites/up_n_down/lollypop.npy new file mode 100644 index 0000000000000000000000000000000000000000..8b173fb39c5f4dab748861c12c3637b657a718a0 GIT binary patch literal 640 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1JQ);NLqoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$-W;zNMItnJ5ItsN4WC1P)RPZ0jg<%w#Np1)JBV(vKm>Lv0kUDzHgY3iS dKB#+O?t;-U_k(Dd7>q`igRwz!_`r~p2LQ9GcVhqm literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/lollypop_backup.npy b/src/jaxatari/games/sprites/up_n_down/lollypop_backup.npy new file mode 100644 index 0000000000000000000000000000000000000000..10f3b622549b380a8d8a308727a0f896918f270d GIT binary patch literal 512 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1JQ);NLqoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$-W;zN$WTL5~P^&-|;9@|4|3D@TgGDE~9r%xoVe-ggAU3tdLFOUb0W%*5 V4HSmE8|D^}7z}{Kuz^7#4gec9 zcpO6(M>Yd22bG7b2Qgp(BnQ(Eqk(cTGf?DU@=zLN1cp3F4yF!j4%lrlaTtwk2h1Ea iIhY%O(hz&_$-~ruXe10|f&sD^j17~6(J*lk4F&)N;f|pI literal 0 HcmV?d00001 From 0cec9fbe425074b0ff37539c322614b3e423d880 Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Wed, 17 Dec 2025 16:43:57 +0100 Subject: [PATCH 15/76] add score and display and passive point gain --- src/jaxatari/games/jax_upndown.py | 72 ++++++++++++++++++ .../games/sprites/up_n_down/score/score_0.npy | Bin 0 -> 268 bytes .../games/sprites/up_n_down/score/score_1.npy | Bin 0 -> 212 bytes .../games/sprites/up_n_down/score/score_2.npy | Bin 0 -> 268 bytes .../games/sprites/up_n_down/score/score_3.npy | Bin 0 -> 268 bytes .../games/sprites/up_n_down/score/score_4.npy | Bin 0 -> 268 bytes .../games/sprites/up_n_down/score/score_5.npy | Bin 0 -> 268 bytes .../games/sprites/up_n_down/score/score_6.npy | Bin 0 -> 268 bytes .../games/sprites/up_n_down/score/score_7.npy | Bin 0 -> 268 bytes .../games/sprites/up_n_down/score/score_8.npy | Bin 0 -> 268 bytes .../games/sprites/up_n_down/score/score_9.npy | Bin 0 -> 268 bytes 11 files changed, 72 insertions(+) create mode 100644 src/jaxatari/games/sprites/up_n_down/score/score_0.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/score/score_1.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/score/score_2.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/score/score_3.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/score/score_4.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/score/score_5.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/score/score_6.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/score/score_7.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/score/score_8.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/score/score_9.npy diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index e469364bf..a1646fa43 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -114,6 +114,8 @@ class UpNDownState(NamedTuple): is_on_road: chex.Array player_car: Car step_counter: chex.Array + round_started: chex.Array + movement_steps: chex.Array # Flag state - tracks all 8 flags flags: Flag # Contains arrays of size NUM_FLAGS for each field flags_collected_mask: chex.Array # Boolean mask of which flag colors have been collected (size NUM_FLAGS) @@ -621,6 +623,13 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: type=state.player_car.type, ), step_counter=state.step_counter + 1, + round_started=jnp.logical_or(state.round_started, player_speed != 0), + movement_steps=jax.lax.cond( + jnp.logical_or(state.round_started, player_speed != 0), + lambda s: state.movement_steps + 1, + lambda s: state.movement_steps, + operand=None, + ), flags=state.flags, flags_collected_mask=state.flags_collected_mask, collectibles=state.collectibles, @@ -645,6 +654,8 @@ def _flag_step_main(self, state: UpNDownState) -> UpNDownState: is_on_road=state.is_on_road, player_car=state.player_car, step_counter=state.step_counter, + round_started=state.round_started, + movement_steps=state.movement_steps, flags=new_flags, flags_collected_mask=new_flags_collected_mask, collectibles=state.collectibles, @@ -669,12 +680,39 @@ def _collectible_step_main(self, state: UpNDownState) -> UpNDownState: is_on_road=state.is_on_road, player_car=state.player_car, step_counter=state.step_counter, + round_started=state.round_started, + movement_steps=state.movement_steps, flags=state.flags, flags_collected_mask=state.flags_collected_mask, collectibles=updated_collectibles, collectible_spawn_timer=new_collectible_timer, ) + def _passive_score_step_main(self, state: UpNDownState) -> UpNDownState: + """Award passive score every 60 steps after the player has started moving.""" + bonus = jax.lax.cond( + jnp.logical_and(state.round_started, state.movement_steps % 60 == 0), + lambda _: jnp.int32(10), + lambda _: jnp.int32(0), + operand=None, + ) + + return UpNDownState( + score=state.score + bonus, + difficulty=state.difficulty, + jump_cooldown=state.jump_cooldown, + is_jumping=state.is_jumping, + is_on_road=state.is_on_road, + player_car=state.player_car, + step_counter=state.step_counter, + round_started=state.round_started, + movement_steps=state.movement_steps, + flags=state.flags, + flags_collected_mask=state.flags_collected_mask, + collectibles=state.collectibles, + collectible_spawn_timer=state.collectible_spawn_timer, + ) + def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: # Initialize flags at random positions along the track @@ -741,6 +779,8 @@ def get_road_segment(y): type=0, ), step_counter=jnp.array(0), + round_started=jnp.array(False), + movement_steps=jnp.array(0), flags=flags, flags_collected_mask=jnp.zeros(self.consts.NUM_FLAGS, dtype=jnp.bool_), collectibles=collectibles, @@ -753,6 +793,7 @@ def get_road_segment(y): def step(self, state: UpNDownState, action: chex.Array) -> Tuple[UpNDownObservation, UpNDownState, float, bool, UpNDownInfo]: previous_state = state state = self._player_step(state, action) + state = self._passive_score_step_main(state) state = self._flag_step_main(state) state = self._collectible_step_main(state) @@ -880,6 +921,13 @@ def __init__(self, consts: UpNDownConstants = None): self.ice_cream_solid_mask = self.ice_cream_base_mask != self.jr.TRANSPARENT_ID self.ice_cream_palette_ids = self._compute_flag_palette_ids() + # Score rendering helpers + self.score_digit_masks = self.SHAPE_MASKS["score_digits"] + self.score_max_digits = 6 + self.score_digit_spacing = int(self.score_digit_masks.shape[2]) + 1 + self.score_render_y = 6 + self.score_center_x = self.config.game_dimensions[1] // 2 - self.config.game_dimensions[1] // 4 + def _createBackgroundSprite(self, dimensions: Tuple[int, int]) -> jnp.ndarray: """Creates a procedural background sprite for the game.""" height, width = dimensions @@ -929,6 +977,7 @@ def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.n {'name': 'wall_bottom', 'type': 'procedural', 'data': bottomBlockSprite}, {'name': 'all_flags_top', 'type': 'single', 'file': 'all_flags_top.npy'}, {'name': 'all_lives_bottom', 'type': 'single', 'file': 'all_lives_bottom.npy'}, + {'name': 'score_digits', 'type': 'digits', 'pattern': 'score/score_{}.npy'}, {'name': 'pink_flag', 'type': 'single', 'file': 'pink_flag.npy'}, {'name': 'flag_pole', 'type': 'single', 'file': 'flag_pole.npy'}, {'name': 'cherry', 'type': 'single', 'file': 'cherry.npy'}, @@ -1001,6 +1050,29 @@ def combine(i, acc): all_flags_top_mask = self.SHAPE_MASKS["all_flags_top"] raster = self.jr.render_at(raster, 10, 20, all_flags_top_mask) + # Render score centered at the top using dedicated score digit sprites + score_digits = self.jr.int_to_digits(state.score, max_digits=self.score_max_digits) + non_zero_mask = score_digits != 0 + has_non_zero = jnp.any(non_zero_mask) + first_non_zero = jnp.argmax(non_zero_mask) + start_index = jax.lax.select(has_non_zero, first_non_zero, self.score_max_digits - 1) + num_to_render = jax.lax.select(has_non_zero, self.score_max_digits - start_index, 1) + + total_width = num_to_render * self.score_digit_spacing + score_x = self.score_center_x - (total_width // 2) + + raster = self.jr.render_label_selective( + raster, + jnp.int32(score_x), + self.score_render_y, + score_digits, + self.score_digit_masks, + start_index, + num_to_render, + spacing=self.score_digit_spacing, + max_digits_to_render=self.score_max_digits, + ) + # Render flags on the road flag_pole_mask = self.SHAPE_MASKS["flag_pole"] diff --git a/src/jaxatari/games/sprites/up_n_down/score/score_0.npy b/src/jaxatari/games/sprites/up_n_down/score/score_0.npy new file mode 100644 index 0000000000000000000000000000000000000000..1988fab2213d0131f9fc800388d86d05d457642a GIT binary patch literal 268 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1JQ);NLqoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I%|Itr#b3MQI53bhL40WJn0SYgorABJIk5Df#!a?}z7nFBKyM1%MML?=Q& literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/score/score_1.npy b/src/jaxatari/games/sprites/up_n_down/score/score_1.npy new file mode 100644 index 0000000000000000000000000000000000000000..3f847e8500eb0fc9b04da261e1e87682d341582a GIT binary patch literal 212 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1JQ);NLqoAIaUsO_*m=~X4l#&V(cT3DEP6dh= vXCxM+0{I%|Its=*3MQI53bhL40WJn0SYgorA4bF2Fg}O|0Yc&+c^C!&E&emv literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/score/score_2.npy b/src/jaxatari/games/sprites/up_n_down/score/score_2.npy new file mode 100644 index 0000000000000000000000000000000000000000..bd02cd52432ad8246189e966d893924a46641889 GIT binary patch literal 268 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1JQ);NLqoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I%|Itr#b3MQI53bhL40WJn0SYgorABJIk5Dfz$IT#y8Ba0#H1<51JVdKNZ IVESM*0Ncbr*#H0l literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/score/score_3.npy b/src/jaxatari/games/sprites/up_n_down/score/score_3.npy new file mode 100644 index 0000000000000000000000000000000000000000..775039f6e5d3e2e4efe05725d381d6cbd5a2fa33 GIT binary patch literal 268 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1JQ);NLqoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I%|Itr#b3MQI53bhL40j?DW{r}N10}=qqA#oYd`5-x%eq?npHa0n!88A5* F9{@E%KiL2P literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/score/score_4.npy b/src/jaxatari/games/sprites/up_n_down/score/score_4.npy new file mode 100644 index 0000000000000000000000000000000000000000..e3aa799b2b9176f32c5cca44fa71a1c4eb6931cb GIT binary patch literal 268 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1JQ);NLqoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I%|Itr#b3MQI53bhL40WJn4u)?7KKN6P##77o`vC(Oe9%L~P8yTaUiOfgV F3jn5nJ#zp6 literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/score/score_5.npy b/src/jaxatari/games/sprites/up_n_down/score/score_5.npy new file mode 100644 index 0000000000000000000000000000000000000000..398b06f212008f40045f40ba1a0cf9a124e3b56e GIT binary patch literal 268 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1JQ);NLqoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I%|Itr#b3MQI53bhL40j?DW{r|CH1~h=K22C2w1*t(7hv`EWhpB_f!T10? Cazv#7 literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/score/score_6.npy b/src/jaxatari/games/sprites/up_n_down/score/score_6.npy new file mode 100644 index 0000000000000000000000000000000000000000..dede86feab0a5f50f9fb62b99c41e2414c64140e GIT binary patch literal 268 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1JQ);NLqoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I%|Itr#b3MQI53bhL40WJn8SYgorABI6fAT~@68y_YH(*sivqhWes@-RM3 JAB;wp2LQ)lLO%ci literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/score/score_7.npy b/src/jaxatari/games/sprites/up_n_down/score/score_7.npy new file mode 100644 index 0000000000000000000000000000000000000000..cb6269b439cd04963c8ca90e798c9eef08394f7f GIT binary patch literal 268 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1JQ);NLqoAIaUsO_*m=~X4l#&V(cT3DEP6dh= vXCxM+0{I%|Itr#b3MQI53bhL40j?DW{r}N10~!FyBMVW=2bqO#21pD5Oz$>9 literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/score/score_8.npy b/src/jaxatari/games/sprites/up_n_down/score/score_8.npy new file mode 100644 index 0000000000000000000000000000000000000000..a422d51a0785771ecdd2902228a23e58ea18349e GIT binary patch literal 268 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1JQ);NLqoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I%|Itr#b3MQI53bhL40WJn0SYgorABJIk5Df#!a>!yJHcSpg!}u^YFg{EU KCJy7H>jePpj6|gX literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/score/score_9.npy b/src/jaxatari/games/sprites/up_n_down/score/score_9.npy new file mode 100644 index 0000000000000000000000000000000000000000..5a618eedb407c28a4c7464b4cb6e23408149e859 GIT binary patch literal 268 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1JQ);NLqoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I%|Itr#b3MQI53bhL40WJn0SYgorABJIk5Df#!a>!yJHcSpi!{lK!NF2rn N(J($tJ&X^M0|3C^LO%ci literal 0 HcmV?d00001 From 305b5a5c8658e21b702d862b1e8007386bff797f Mon Sep 17 00:00:00 2001 From: shaik05 Date: Sun, 21 Dec 2025 10:07:21 +0100 Subject: [PATCH 16/76] a --- src/jaxatari/games/jax_upndown.py | 166 ++++++++++++++++++++++++------ 1 file changed, 137 insertions(+), 29 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index a1646fa43..5a9987a2c 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -17,7 +17,13 @@ class UpNDownConstants(NamedTuple): DIFFICULTIES: chex.Array = jnp.array([0, 1, 2, 3, 4, 5]) ACTION_REPEAT_PROBS: float = 0.25 MAX_SPEED: int = 4 + INITIAL_LIVES: int = 3 + RESPAWN_Y: int = 0 + RESPAWN_X: int = 30 + RESPAWN_DELAY_FRAMES: int = 60 + WATER_DEATH_PENALTY: int = 0 JUMP_FRAMES: int = 10 + ALL_FLAGS_BONUS: int = 1000 LANDING_ZONE: int = 15 FIRST_ROAD_LENGTH: int = 4 SECOND_ROAD_LENGTH: int = 4 @@ -113,6 +119,9 @@ class UpNDownState(NamedTuple): is_jumping: chex.Array is_on_road: chex.Array player_car: Car + lives: chex.Array + is_dead: chex.Array + respawn_timer: chex.Array step_counter: chex.Array round_started: chex.Array movement_steps: chex.Array @@ -421,11 +430,63 @@ def get_collection_score(idx): ) return updated_collectibles, score_delta, new_collectible_timer + def _death_step(self, state: UpNDownState) -> UpNDownState: + # Player on water road (index 2 assumed water) + died = jnp.logical_and( + state.player_car.current_road == 2, + ~state.is_dead, + ) + + lives = jax.lax.cond( + died, + lambda _: state.lives - 1, + lambda _: state.lives, + operand=None, + ) + lives = jax.lax.cond( + died, + lambda _: state.lives - 1, + lambda _: state.lives, + operand=None, + ) + respawn_timer = jax.lax.cond( + died, + lambda _: jnp.array(self.consts.RESPAWN_DELAY_FRAMES), + lambda _: jnp.maximum(state.respawn_timer - 1, 0), + operand=None, + ) + is_dead = jnp.logical_and( + jnp.logical_or(state.is_dead, died), + respawn_timer > 0) + + player_car = jax.lax.cond( + jnp.logical_and(state.is_dead, respawn_timer == 0), + lambda _: state.player_car._replace( + position=state.player_car.position._replace( + x=jnp.array(self.consts.RESPAWN_X, dtype=jnp.float32), + y=jnp.array(self.consts.RESPAWN_Y, dtype=jnp.float32), + ), + speed=0, + current_road=0, + ), + lambda _: state.player_car, + operand=None, + ) + return state._replace( + lives=lives, + is_dead=is_dead, + respawn_timer=respawn_timer, + player_car=player_car, + ) + def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: up = jnp.logical_or(action == Action.UP, action == Action.UPFIRE) down = jnp.logical_or(action == Action.DOWN, action == Action.DOWNFIRE) jump = jnp.logical_or(action == Action.FIRE, jnp.logical_or(action == Action.UPFIRE, action == Action.DOWNFIRE)) + lives=jnp.array(self.consts.INITIAL_LIVES, dtype=jnp.int32), + is_dead=jnp.array(False), + respawn_timer=jnp.array(0, dtype=jnp.int32), @@ -622,6 +683,9 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: road_index_B=road_index_B, type=state.player_car.type, ), + lives=state.lives, + is_dead=state.is_dead, + respawn_timer=state.respawn_timer, step_counter=state.step_counter + 1, round_started=jnp.logical_or(state.round_started, player_speed != 0), movement_steps=jax.lax.cond( @@ -653,6 +717,9 @@ def _flag_step_main(self, state: UpNDownState) -> UpNDownState: is_jumping=state.is_jumping, is_on_road=state.is_on_road, player_car=state.player_car, + lives=state.lives, + is_dead=state.is_dead, + respawn_timer=state.respawn_timer, step_counter=state.step_counter, round_started=state.round_started, movement_steps=state.movement_steps, @@ -661,6 +728,18 @@ def _flag_step_main(self, state: UpNDownState) -> UpNDownState: collectibles=state.collectibles, collectible_spawn_timer=state.collectible_spawn_timer, ) + def _completion_bonus_step(self, state: UpNDownState) -> UpNDownState: + all_flags_collected = jnp.all(state.flags_collected_mask) + + bonus = jax.lax.cond( + all_flags_collected, + lambda _: self.consts.ALL_FLAGS_BONUS, + lambda _: 0, + operand=None, + ) + return state._replace(score=state.score + bonus,lives=state.lives, + is_dead=state.is_dead, + respawn_timer=state.respawn_timer,) def _collectible_step_main(self, state: UpNDownState) -> UpNDownState: """Update collectible spawning, despawning, and collection.""" @@ -679,6 +758,9 @@ def _collectible_step_main(self, state: UpNDownState) -> UpNDownState: is_jumping=state.is_jumping, is_on_road=state.is_on_road, player_car=state.player_car, + lives=state.lives, + is_dead=state.is_dead, + respawn_timer=state.respawn_timer, step_counter=state.step_counter, round_started=state.round_started, movement_steps=state.movement_steps, @@ -704,6 +786,9 @@ def _passive_score_step_main(self, state: UpNDownState) -> UpNDownState: is_jumping=state.is_jumping, is_on_road=state.is_on_road, player_car=state.player_car, + lives=state.lives, + is_dead=state.is_dead, + respawn_timer=state.respawn_timer, step_counter=state.step_counter, round_started=state.round_started, movement_steps=state.movement_steps, @@ -729,6 +814,9 @@ def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: # Alternate roads 0/1 for variety flag_roads = jnp.arange(self.consts.NUM_FLAGS) % 2 + + + # Calculate which road segment each flag is on based on Y position def get_road_segment(y): # Find the segment where TRACK_CORNERS_Y[i] > y >= TRACK_CORNERS_Y[i+1] @@ -757,35 +845,45 @@ def get_road_segment(y): type_id=jnp.zeros(self.consts.MAX_COLLECTIBLES, dtype=jnp.int32), active=jnp.zeros(self.consts.MAX_COLLECTIBLES, dtype=jnp.bool_), ) - + player_car = Car( + position=EntityPosition( + x=jnp.array(30, dtype=jnp.int32), + y=jnp.array(0, dtype=jnp.int32), + width=jnp.array(self.consts.PLAYER_SIZE[0], dtype=jnp.int32), + height=jnp.array(self.consts.PLAYER_SIZE[1], dtype=jnp.int32), + ), + speed=jnp.array(0, dtype=jnp.int32), + direction_x=jnp.array(0, dtype=jnp.int32), + current_road=jnp.array(0, dtype=jnp.int32), + road_index_A=jnp.array(0, dtype=jnp.int32), + road_index_B=jnp.array(0, dtype=jnp.int32), + type=jnp.array(0, dtype=jnp.int32), + ) state = UpNDownState( - score=0, - difficulty=self.consts.DIFFICULTIES[0], - jump_cooldown=0, - is_jumping=False, - is_on_road=True, - player_car=Car( - position=EntityPosition( - x=30, - y= 0, - width=self.consts.PLAYER_SIZE[0], - height=self.consts.PLAYER_SIZE[1], - ), - speed=0, - direction_x=0, - current_road=0, - road_index_A=0, - road_index_B=0, - type=0, - ), - step_counter=jnp.array(0), - round_started=jnp.array(False), - movement_steps=jnp.array(0), - flags=flags, - flags_collected_mask=jnp.zeros(self.consts.NUM_FLAGS, dtype=jnp.bool_), - collectibles=collectibles, - collectible_spawn_timer=jnp.array(0, dtype=jnp.int32), - ) + score=jnp.array(0, dtype=jnp.int32), + difficulty=jnp.array(self.consts.DIFFICULTIES[0], dtype=jnp.int32), + jump_cooldown=jnp.array(0, dtype=jnp.int32), + is_jumping=jnp.array(False), + is_on_road=jnp.array(True), + + player_car=player_car, + + step_counter=jnp.array(0, dtype=jnp.int32), + round_started=jnp.array(False), + movement_steps=jnp.array(0, dtype=jnp.int32), + + flags=flags, + flags_collected_mask=jnp.zeros(self.consts.NUM_FLAGS, dtype=jnp.bool_), + collectibles=collectibles, + collectible_spawn_timer=jnp.array(0, dtype=jnp.int32), + + # -------- NEW REQUIRED FIELDS -------- + lives=jnp.array(self.consts.INITIAL_LIVES, dtype=jnp.int32), + is_dead=jnp.array(False), + respawn_timer=jnp.array(0, dtype=jnp.int32), + ) + + initial_obs = self._get_observation(state) return initial_obs, state @@ -793,9 +891,14 @@ def get_road_segment(y): def step(self, state: UpNDownState, action: chex.Array) -> Tuple[UpNDownObservation, UpNDownState, float, bool, UpNDownInfo]: previous_state = state state = self._player_step(state, action) + state = self._death_step(state) + state = self._passive_score_step_main(state) state = self._flag_step_main(state) + state = self._completion_bonus_step(state) state = self._collectible_step_main(state) + + done = self._get_done(state) env_reward = self._get_reward(previous_state, state) @@ -860,7 +963,12 @@ def _get_reward(self, previous_state: UpNDownState, state: UpNDownState): @partial(jax.jit, static_argnums=(0,)) def _get_done(self, state: UpNDownState) -> bool: - return jnp.logical_not(True) + return jnp.logical_or( + state.lives <= 0, + jnp.all(state.flags_collected_mask), +) + + class UpNDownRenderer(JAXGameRenderer): def __init__(self, consts: UpNDownConstants = None): From 55dd4419934a5ad6bf565a0e052d23173d1fd916 Mon Sep 17 00:00:00 2001 From: shaik05 Date: Sun, 21 Dec 2025 10:21:32 +0100 Subject: [PATCH 17/76] movement --- src/jaxatari/games/upndown_interface.py | 53 ------------------------- 1 file changed, 53 deletions(-) delete mode 100644 src/jaxatari/games/upndown_interface.py diff --git a/src/jaxatari/games/upndown_interface.py b/src/jaxatari/games/upndown_interface.py deleted file mode 100644 index 68f8c76fe..000000000 --- a/src/jaxatari/games/upndown_interface.py +++ /dev/null @@ -1,53 +0,0 @@ -import jax -import jax.numpy as jnp -import matplotlib.pyplot as plt - -from jaxatari.environment import JAXAtariAction as Action -from upndown import JaxUpNDown, UpNDownConstants # <-- your game file - -def visualize_frame(frame: jnp.ndarray): - """Render an RGB frame using matplotlib.""" - plt.imshow(frame.astype(jnp.uint8)) - plt.axis("off") - plt.show(block=False) - plt.pause(0.05) - plt.clf() - - -def main(): - # Initialize environment - env = JaxUpNDown(UpNDownConstants()) - - # Reset environment - obs, state = env.reset() - print("Initial observation:", obs) - - # Display initial render - frame = env.render(state) - visualize_frame(frame) - - # Create a random key for sampling actions - key = jax.random.PRNGKey(0) - - # Run for 50 steps - for step in range(50): - key, subkey = jax.random.split(key) - # Choose a random action from action space - action = jax.random.choice(subkey, jnp.arange(len(env.action_set))) - - obs, state, reward, done, info = env.step(state, action) - - # Render and display - frame = env.render(state) - visualize_frame(frame) - - print(f"Step {step}: action={env.action_set[int(action)]}, reward={reward}, done={done}") - - if bool(done): - print("Game over — resetting environment.") - obs, state = env.reset() - - plt.close() - -if __name__ == "__main__": - main() From ba94abfd1d6c6b813485f3295ca29b670d434d16 Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Thu, 18 Dec 2025 17:34:33 +0100 Subject: [PATCH 18/76] add enemy cars & enemy logic to game --- src/jaxatari/games/jax_upndown.py | 796 ++++++++++++------ .../up_n_down/enemy_cars/camero_left.npy | Bin 0 -> 1024 bytes .../up_n_down/enemy_cars/camero_right.npy | Bin 0 -> 1024 bytes .../enemy_cars/flag_carrier_left.npy | Bin 0 -> 1152 bytes .../enemy_cars/flag_carrier_right.npy | Bin 0 -> 1152 bytes .../enemy_cars/pick_up_truck_left.npy | Bin 0 -> 1152 bytes .../enemy_cars/pick_up_truck_right.npy | Bin 0 -> 1152 bytes .../up_n_down/enemy_cars/truck_left.npy | Bin 0 -> 1024 bytes .../up_n_down/enemy_cars/truck_right.npy | Bin 0 -> 1024 bytes 9 files changed, 559 insertions(+), 237 deletions(-) create mode 100644 src/jaxatari/games/sprites/up_n_down/enemy_cars/camero_left.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/enemy_cars/camero_right.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/enemy_cars/flag_carrier_left.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/enemy_cars/flag_carrier_right.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/enemy_cars/pick_up_truck_left.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/enemy_cars/pick_up_truck_right.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/enemy_cars/truck_left.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/enemy_cars/truck_right.npy diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 5a9987a2c..648859a53 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -16,12 +16,24 @@ class UpNDownConstants(NamedTuple): FRAME_SKIP: int = 4 DIFFICULTIES: chex.Array = jnp.array([0, 1, 2, 3, 4, 5]) ACTION_REPEAT_PROBS: float = 0.25 - MAX_SPEED: int = 4 - INITIAL_LIVES: int = 3 - RESPAWN_Y: int = 0 - RESPAWN_X: int = 30 - RESPAWN_DELAY_FRAMES: int = 60 - WATER_DEATH_PENALTY: int = 0 + MAX_SPEED: int = 5 + # Enemy spawning and movement + MAX_ENEMY_CARS: int = 6 + ENEMY_SPAWN_INTERVAL: int = 80 + ENEMY_DESPAWN_DISTANCE: int = 300 + ENEMY_SPEED_MIN: int = 2 + ENEMY_SPEED_MAX: int = 5 + ENEMY_DIRECTION_SWITCH_PROB: float = 0.005 + ENEMY_OFFSCREEN_SPAWN_OFFSET: float = 100.0 + ENEMY_MIN_SPAWN_GAP: float = 40.0 + ENEMY_MAX_AGE: int = 900 + INITIAL_ENEMY_COUNT: int = 3 + INITIAL_ENEMY_BASE_OFFSET: float = 40.0 + INITIAL_ENEMY_GAP: float = 50.0 + ENEMY_TYPE_CAMERO: int = 0 + ENEMY_TYPE_FLAG_CARRIER: int = 1 + ENEMY_TYPE_PICKUP: int = 2 + ENEMY_TYPE_TRUCK: int = 3 JUMP_FRAMES: int = 10 ALL_FLAGS_BONUS: int = 1000 LANDING_ZONE: int = 15 @@ -112,6 +124,19 @@ class Collectible(NamedTuple): type_id: chex.Array # Type of collectible (0=cherry, 1=balloon, 2=lollypop, 3=ice_cream) active: chex.Array # Whether this collectible slot is active (spawned) + +class EnemyCars(NamedTuple): + """Pool of enemy cars that share the same road-following logic as the player.""" + position: EntityPosition # vectorized position fields, size MAX_ENEMY_CARS + speed: chex.Array # signed speed per car + type: chex.Array # type id per car + current_road: chex.Array + road_index_A: chex.Array + road_index_B: chex.Array + direction_x: chex.Array + active: chex.Array + age: chex.Array + class UpNDownState(NamedTuple): score: chex.Array difficulty: chex.Array @@ -131,6 +156,9 @@ class UpNDownState(NamedTuple): # Collectible state - dynamic spawning (mixed types: cherry, balloon, lollypop, ice cream) collectibles: Collectible # Contains arrays of size MAX_COLLECTIBLES for each field collectible_spawn_timer: chex.Array # Counter for collectible spawn timing + # Enemy cars - dynamic spawning and movement + enemy_cars: EnemyCars + enemy_spawn_timer: chex.Array @@ -161,20 +189,39 @@ def __init__(self, consts: UpNDownConstants = None, reward_funcs: list[callable] self.obs_size = 3*4+1+1 @partial(jax.jit, static_argnums=(0,)) - def _getSlopeAndB(self, state: UpNDownState) -> chex.Array: - trackx, tracky, roadIndex = jax.lax.cond( - state.player_car.current_road == 0, - lambda s: (self.consts.FIRST_TRACK_CORNERS_X, self.consts.TRACK_CORNERS_Y, state.player_car.road_index_A), - lambda s: (self.consts.SECOND_TRACK_CORNERS_X, self.consts.TRACK_CORNERS_Y, state.player_car.road_index_B), - operand=None,) + def _get_slope_and_intercept_from_indices(self, current_road: chex.Array, road_index_A: chex.Array, road_index_B: chex.Array) -> Tuple[chex.Array, chex.Array]: + trackx, tracky, road_index = jax.lax.cond( + current_road == 0, + lambda _: (self.consts.FIRST_TRACK_CORNERS_X, self.consts.TRACK_CORNERS_Y, road_index_A), + lambda _: (self.consts.SECOND_TRACK_CORNERS_X, self.consts.TRACK_CORNERS_Y, road_index_B), + operand=None, + ) slope = jax.lax.cond( - trackx[roadIndex+1] - trackx[roadIndex] != 0, - lambda s: (tracky[roadIndex+1] - tracky[roadIndex]) / (trackx[roadIndex+1] - trackx[roadIndex]), - lambda s: 300.0, + trackx[road_index+1] - trackx[road_index] != 0, + lambda _: (tracky[road_index+1] - tracky[road_index]) / (trackx[road_index+1] - trackx[road_index]), + lambda _: 300.0, operand=None, ) - b = tracky[roadIndex] - slope * trackx[roadIndex] + b = tracky[road_index] - slope * trackx[road_index] return slope, b + + @partial(jax.jit, static_argnums=(0,)) + def _getSlopeAndB(self, state: UpNDownState) -> chex.Array: + return self._get_slope_and_intercept_from_indices( + state.player_car.current_road, + state.player_car.road_index_A, + state.player_car.road_index_B, + ) + + @partial(jax.jit, static_argnums=(0,)) + def _is_on_line_for_position(self, position: EntityPosition, slope: chex.Array, b: chex.Array, player_speed: chex.Array, turn: chex.Array) -> chex.Array: + x_step = abs(jnp.subtract(position.y, slope * (position.x) + b)) + y_step = abs(jnp.subtract(position.y - player_speed, slope * position.x + b)) + prefer_y = jnp.less_equal(y_step, x_step) + return jnp.logical_or( + jnp.logical_and(turn == 1, prefer_y), + jnp.logical_and(turn == 2, jnp.logical_not(prefer_y)), + ) @partial(jax.jit, static_argnums=(0,)) def _get_x_on_road(self, y: chex.Array, road_segment: chex.Array, track_corners_x: chex.Array) -> chex.Array: @@ -196,13 +243,7 @@ def _get_x_on_road(self, y: chex.Array, road_segment: chex.Array, track_corners_ @partial(jax.jit, static_argnums=(0,)) def _isOnLine(self, state: UpNDownState, player_speed: chex.Array, turn: chex.Array) -> chex.Array: slope, b = self._getSlopeAndB(state) - x_step = abs(jnp.subtract(state.player_car.position.y, slope * (state.player_car.position.x) + b)) - y_step = abs(jnp.subtract(state.player_car.position.y - player_speed, slope * state.player_car.position.x + b)) - prefer_y = jnp.less_equal(y_step, x_step) - return jnp.logical_or( - jnp.logical_and(turn == 1, prefer_y), - jnp.logical_and(turn == 2, jnp.logical_not(prefer_y)), - ) + return self._is_on_line_for_position(state.player_car.position, slope, b, player_speed, turn) @partial(jax.jit, static_argnums=(0,)) def _landing_in_water(self, state: UpNDownState, new_position_x: chex.Array, new_position_y: chex.Array) -> chex.Array: @@ -214,6 +255,181 @@ def _landing_in_water(self, state: UpNDownState, new_position_x: chex.Array, new between_roads = jnp.logical_and(new_position_x > jnp.minimum(road_A_x, road_B_x), new_position_x < jnp.maximum(road_A_x, road_B_x)) return landing_in_Water, between_roads, road_A_x, road_B_x + @partial(jax.jit, static_argnums=(0,)) + def _landing_in_water_for_indices(self, road_index_A: chex.Array, road_index_B: chex.Array, new_position_x: chex.Array, new_position_y: chex.Array) -> chex.Array: + road_A_x = ((new_position_y - self.consts.TRACK_CORNERS_Y[road_index_A]) / (self.consts.TRACK_CORNERS_Y[road_index_A+1] - self.consts.TRACK_CORNERS_Y[road_index_A])) * (self.consts.FIRST_TRACK_CORNERS_X[road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_X[road_index_A]) + self.consts.FIRST_TRACK_CORNERS_X[road_index_A] + road_B_x = ((new_position_y - self.consts.TRACK_CORNERS_Y[road_index_B]) / (self.consts.TRACK_CORNERS_Y[road_index_B+1] - self.consts.TRACK_CORNERS_Y[road_index_B])) * (self.consts.SECOND_TRACK_CORNERS_X[road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[road_index_B]) + self.consts.SECOND_TRACK_CORNERS_X[road_index_B] + distance_to_road_A = jnp.abs(new_position_x - road_A_x) + distance_to_road_B = jnp.abs(new_position_x - road_B_x) + landing_in_water = jnp.logical_and(distance_to_road_A > self.consts.LANDING_ZONE, distance_to_road_B > self.consts.LANDING_ZONE) + between_roads = jnp.logical_and(new_position_x > jnp.minimum(road_A_x, road_B_x), new_position_x < jnp.maximum(road_A_x, road_B_x)) + return landing_in_water, between_roads, road_A_x, road_B_x + + @partial(jax.jit, static_argnums=(0,)) + def _advance_car_core( + self, + position_x: chex.Array, + position_y: chex.Array, + road_index_A: chex.Array, + road_index_B: chex.Array, + current_road: chex.Array, + speed: chex.Array, + is_jumping: chex.Array, + is_on_road: chex.Array, + step_counter: chex.Array, + width: chex.Array, + height: chex.Array, + car_type: chex.Array, + landing_check: chex.Array, + ) -> Car: + dividers = jnp.array([0, 1, 2, 4, 8, 16]) + abs_speed = jnp.abs(speed) + speed_divider = dividers[abs_speed] + effective_divider = jnp.maximum(1, speed_divider) + period = jnp.maximum(1, 16 // effective_divider) + half_period = jnp.maximum(1, period // 2) + speed_sign = jax.lax.cond(speed != 0, lambda _: jax.lax.abs(speed) / speed, lambda _: jnp.array(0.0), operand=None) + + slope, b = self._get_slope_and_intercept_from_indices(current_road, road_index_A, road_index_B) + + direction_raw = jax.lax.cond( + current_road == 0, + lambda _: self.consts.FIRST_TRACK_CORNERS_X[road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_X[road_index_A], + lambda _: self.consts.SECOND_TRACK_CORNERS_X[road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[road_index_B], + operand=None, + ) + car_direction_x = jax.lax.cond(direction_raw > 0, lambda _: 1, lambda _: -1, operand=None) + + move_y = jnp.logical_and((step_counter % period) == (half_period % period), speed != 0) + move_x = jnp.logical_and((step_counter % period) == 0, speed != 0) + + position = EntityPosition(x=position_x, y=position_y, width=width, height=height) + + new_player_y = jax.lax.cond( + move_y, + lambda _: jax.lax.cond( + is_jumping, + lambda _: position_y + speed_sign * -1, + lambda _: jax.lax.cond( + self._is_on_line_for_position(position, slope, b, speed_sign, 1), + lambda _: position_y + speed_sign * -1, + lambda _: jnp.array(position_y, float), + operand=None, + ), + operand=None, + ), + lambda _: jnp.array(position_y, float), + operand=None, + ) + + new_player_x = jax.lax.cond( + move_x, + lambda _: jax.lax.cond( + is_jumping, + lambda _: position_x + speed_sign * car_direction_x, + lambda _: jax.lax.cond( + self._is_on_line_for_position(position, slope, b, speed_sign, 2), + lambda _: position_x + speed_sign * car_direction_x, + lambda _: jnp.array(position_x, float), + operand=None, + ), + operand=None, + ), + lambda _: jnp.array(position_x, float), + operand=None, + ) + + landing_in_water, between_roads, road_A_x, road_B_x = self._landing_in_water_for_indices(road_index_A, road_index_B, new_player_x, new_player_y) + landing_in_water = jnp.logical_and(landing_check, landing_in_water) + + updated_current_road = jax.lax.cond( + landing_in_water, + lambda _: 2, + lambda _: jax.lax.cond( + is_on_road, + lambda _: current_road, + lambda _: jax.lax.cond( + jnp.abs(new_player_x - road_A_x) < jnp.abs(new_player_x - road_B_x), + lambda _: 0, + lambda _: 1, + operand=None, + ), + operand=None, + ), + operand=None, + ) + + next_road_index_A = jax.lax.cond( + updated_current_road == 2, + lambda _: road_index_A, + lambda _: jax.lax.cond( + self.consts.TRACK_CORNERS_Y[road_index_A] < new_player_y, + lambda _: road_index_A - 1, + lambda _: jax.lax.cond( + len(self.consts.TRACK_CORNERS_Y) == road_index_A + 1, + lambda _: jax.lax.cond( + self.consts.TRACK_CORNERS_Y[0] > new_player_y, + lambda _: 0, + lambda _: road_index_A, + operand=None, + ), + lambda _: jax.lax.cond( + self.consts.TRACK_CORNERS_Y[road_index_A+1] > new_player_y, + lambda _: road_index_A + 1, + lambda _: road_index_A, + operand=None, + ), + operand=None, + ), + operand=None, + ), + operand=None, + ) + + next_road_index_B = jax.lax.cond( + updated_current_road == 2, + lambda _: road_index_B, + lambda _: jax.lax.cond( + self.consts.TRACK_CORNERS_Y[road_index_B] < new_player_y, + lambda _: road_index_B - 1, + lambda _: jax.lax.cond( + len(self.consts.TRACK_CORNERS_Y) == road_index_B + 1, + lambda _: jax.lax.cond( + self.consts.TRACK_CORNERS_Y[0] > new_player_y, + lambda _: 0, + lambda _: road_index_B, + operand=None, + ), + lambda _: jax.lax.cond( + self.consts.TRACK_CORNERS_Y[road_index_B+1] > new_player_y, + lambda _: road_index_B + 1, + lambda _: road_index_B, + operand=None, + ), + operand=None, + ), + operand=None, + ), + operand=None, + ) + + wrapped_y = -((new_player_y * -1) % 1036) + + return Car( + position=EntityPosition( + x=new_player_x, + y=wrapped_y, + width=width, + height=height, + ), + speed=speed, + direction_x=car_direction_x, + current_road=updated_current_road, + road_index_A=next_road_index_A, + road_index_B=next_road_index_B, + type=car_type, + ) + def _flag_step(self, state: UpNDownState, new_player_y: chex.Array, player_x: chex.Array, current_road: chex.Array) -> Tuple[Flag, chex.Array, chex.Array]: """Update flag collection state and score. @@ -484,12 +700,6 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: up = jnp.logical_or(action == Action.UP, action == Action.UPFIRE) down = jnp.logical_or(action == Action.DOWN, action == Action.DOWNFIRE) jump = jnp.logical_or(action == Action.FIRE, jnp.logical_or(action == Action.UPFIRE, action == Action.DOWNFIRE)) - lives=jnp.array(self.consts.INITIAL_LIVES, dtype=jnp.int32), - is_dead=jnp.array(False), - respawn_timer=jnp.array(0, dtype=jnp.int32), - - - player_speed = state.player_car.speed player_speed = jax.lax.cond( @@ -505,9 +715,6 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: lambda s: s, operand=player_speed, ) - dividers = jnp.array([0, 1, 2, 4, 8]) - speed_divider = dividers[jnp.abs(player_speed)] - is_jumping = jnp.logical_or(jnp.logical_and(state.is_jumping, state.jump_cooldown > 0), jnp.logical_and(state.is_on_road, jnp.logical_and(player_speed >= 0, jnp.logical_and(state.jump_cooldown == 0, jump)))) jump_cooldown = jax.lax.cond( @@ -519,173 +726,32 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: operand=None), operand=state.jump_cooldown, ) - - - - - ##check if player is on the the road - is_on_road = ~state.is_jumping - - '''direction_change = jax.lax.cond( - jnp.logical_and(is_on_road, jnp.logical_or(jnp.logical_and(jnp.equal(road_index_A, state.player_car.road_index_A) , state.player_car.current_road == 0), (jnp.logical_and(jnp.equal(road_index_B, state.player_car.road_index_B) , state.player_car.current_road == 1)))) , - lambda s: False, - lambda s: True, - operand=None, - )''' - road_index_A = state.player_car.road_index_A - road_index_B = state.player_car.road_index_B - - car_direction_x = jax.lax.cond(state.player_car.current_road == 0, - lambda s: self.consts.FIRST_TRACK_CORNERS_X[road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_X[road_index_A], - lambda s: self.consts.SECOND_TRACK_CORNERS_X[road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[road_index_B], - operand=None), - car_direction_x = jax.lax.cond( - car_direction_x[0] > 0, - lambda s: 1, - lambda s: -1, - operand=car_direction_x, - ) - + is_on_road = ~is_jumping is_landing = jnp.logical_and(state.jump_cooldown == 1, jump_cooldown == 0) - ##calculate new position with speed (TODO: calculate better speed) - player_y = jax.lax.cond( - jnp.logical_and((state.step_counter % (16/ speed_divider) == 8 / speed_divider), player_speed != 0,), - lambda s: jax.lax.cond( - is_jumping, - lambda s: state.player_car.position.y + jax.lax.abs(player_speed) / player_speed * -1, - lambda s: jax.lax.cond( - self._isOnLine(state, jax.lax.abs(player_speed) / player_speed, 1), - lambda s: s + jax.lax.abs(player_speed) / player_speed * -1, - lambda s: jnp.array(s, float), - operand=state.player_car.position.y, - ), - operand=state.player_car.position.y), - lambda s: jnp.array(s, float), - operand=state.player_car.position.y, - ) - player_x = jax.lax.cond( - jnp.logical_and((state.step_counter % (16/ speed_divider) == 0), player_speed != 0,), - lambda s: jax.lax.cond( - is_jumping, - lambda s: s + jax.lax.abs(player_speed) / player_speed * car_direction_x, - lambda s: jax.lax.cond( - self._isOnLine(state, jax.lax.abs(player_speed) / player_speed, 2), - lambda s: s + jax.lax.abs(player_speed) / player_speed * car_direction_x, - lambda s: jnp.array(s, float), - operand=state.player_car.position.x, - ), - operand=state.player_car.position.x), - lambda s: jnp.array(s, float), - operand=state.player_car.position.x, - ) - - ##if y not on mx +b then no move - - - landing_in_Water, between_roads, road_A_x, road_B_x = self._landing_in_water(state, player_x, player_y) - landing_in_Water = jnp.logical_and(is_landing, landing_in_Water) - - - current_road = jax.lax.cond( - landing_in_Water, - lambda s: 2, - lambda s: jax.lax.cond( - is_on_road, - lambda s: state.player_car.current_road, - lambda s: jax.lax.cond( - jnp.abs(player_x - road_A_x) < jnp.abs(player_x - road_B_x), - lambda s: 0, - lambda s: 1, - operand=None, - ), - operand=None, - ), - operand=None, - ) - - road_index_A = jax.lax.cond( - current_road == 2, - lambda s: road_index_A, - lambda s: jax.lax.cond( - self.consts.TRACK_CORNERS_Y[road_index_A] < player_y, - lambda s: road_index_A - 1, - lambda s: jax.lax.cond( - len(self.consts.TRACK_CORNERS_Y) == road_index_A + 1, - lambda s: jax.lax.cond( - self.consts.TRACK_CORNERS_Y[0] > player_y, - lambda s: 0, - lambda s: road_index_A, - operand=None, - ), - lambda s: jax.lax.cond( - self.consts.TRACK_CORNERS_Y[road_index_A+1] > player_y, - lambda s: road_index_A + 1, - lambda s: road_index_A, - operand=None, - ), - operand=None, - ), - operand=None, - ), - operand=None, - ) - - road_index_B = jax.lax.cond( - current_road == 2, - lambda s: road_index_B, - lambda s: jax.lax.cond( - self.consts.TRACK_CORNERS_Y[road_index_B] < player_y, - lambda s: road_index_B - 1, - lambda s: jax.lax.cond( - len(self.consts.TRACK_CORNERS_Y) == road_index_B + 1, - lambda s: jax.lax.cond( - self.consts.TRACK_CORNERS_Y[0] > player_y, - lambda s: 0, - lambda s: road_index_B, - operand=None, - ), - lambda s: jax.lax.cond( - self.consts.TRACK_CORNERS_Y[road_index_B+1] > player_y, - lambda s: road_index_B + 1, - lambda s: road_index_B, - operand=None, - ), - operand=None, - ), - operand=None, - ), - operand=None, + updated_player_car = self._advance_car_core( + position_x=state.player_car.position.x, + position_y=state.player_car.position.y, + road_index_A=state.player_car.road_index_A, + road_index_B=state.player_car.road_index_B, + current_road=state.player_car.current_road, + speed=player_speed, + is_jumping=is_jumping, + is_on_road=is_on_road, + step_counter=state.step_counter, + width=state.player_car.position.width, + height=state.player_car.position.height, + car_type=state.player_car.type, + landing_check=is_landing, ) - jax.debug.print("Player X: {}, Player Y: {}, car_direction_x: {}", player_x, player_y, car_direction_x) - - # Calculate new player y position after wrapping - new_player_y = -((player_y * -1) % 1036) - return UpNDownState( score=state.score, difficulty=state.difficulty, jump_cooldown=jump_cooldown, is_jumping=is_jumping, is_on_road=is_on_road, - player_car=Car( - position=EntityPosition( - x=player_x, - y=new_player_y, - width=state.player_car.position.width, - height=state.player_car.position.height, - ), - speed=player_speed, - direction_x=car_direction_x, - current_road=current_road, - road_index_A=road_index_A, - road_index_B=road_index_B, - type=state.player_car.type, - ), - lives=state.lives, - is_dead=state.is_dead, - respawn_timer=state.respawn_timer, + player_car=updated_player_car, step_counter=state.step_counter + 1, round_started=jnp.logical_or(state.round_started, player_speed != 0), movement_steps=jax.lax.cond( @@ -698,6 +764,8 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: flags_collected_mask=state.flags_collected_mask, collectibles=state.collectibles, collectible_spawn_timer=state.collectible_spawn_timer, + enemy_cars=state.enemy_cars, + enemy_spawn_timer=state.enemy_spawn_timer, ) def _flag_step_main(self, state: UpNDownState) -> UpNDownState: @@ -727,19 +795,9 @@ def _flag_step_main(self, state: UpNDownState) -> UpNDownState: flags_collected_mask=new_flags_collected_mask, collectibles=state.collectibles, collectible_spawn_timer=state.collectible_spawn_timer, + enemy_cars=state.enemy_cars, + enemy_spawn_timer=state.enemy_spawn_timer, ) - def _completion_bonus_step(self, state: UpNDownState) -> UpNDownState: - all_flags_collected = jnp.all(state.flags_collected_mask) - - bonus = jax.lax.cond( - all_flags_collected, - lambda _: self.consts.ALL_FLAGS_BONUS, - lambda _: 0, - operand=None, - ) - return state._replace(score=state.score + bonus,lives=state.lives, - is_dead=state.is_dead, - respawn_timer=state.respawn_timer,) def _collectible_step_main(self, state: UpNDownState) -> UpNDownState: """Update collectible spawning, despawning, and collection.""" @@ -768,6 +826,159 @@ def _collectible_step_main(self, state: UpNDownState) -> UpNDownState: flags_collected_mask=state.flags_collected_mask, collectibles=updated_collectibles, collectible_spawn_timer=new_collectible_timer, + enemy_cars=state.enemy_cars, + enemy_spawn_timer=state.enemy_spawn_timer, + ) + + @partial(jax.jit, static_argnums=(0,)) + def _enemy_step_main(self, state: UpNDownState) -> UpNDownState: + """Spawn and move enemy cars that share the player's road logic.""" + base_key = jax.random.PRNGKey(2025) + step_key = jax.random.fold_in(base_key, state.step_counter) + key_spawn_offset, key_spawn_side, key_spawn_speed, key_spawn_direction, key_spawn_type, key_spawn_sign, key_flip_root = jax.random.split(step_key, 7) + + active_mask = state.enemy_cars.active + active_count = jnp.sum(active_mask.astype(jnp.int32)) + can_spawn = active_count < self.consts.MAX_ENEMY_CARS + + spawn_timer = jax.lax.cond( + jnp.logical_and(state.enemy_spawn_timer <= 0, can_spawn), + lambda _: self.consts.ENEMY_SPAWN_INTERVAL, + lambda _: state.enemy_spawn_timer - 1, + operand=None, + ) + should_spawn = jnp.logical_and(state.enemy_spawn_timer <= 0, can_spawn) + + inactive_mask = jnp.logical_not(active_mask) + first_inactive = jnp.argmax(inactive_mask.astype(jnp.int32)) + has_inactive = jnp.any(inactive_mask) + spawn_idx = jax.lax.cond(has_inactive, lambda _: first_inactive, lambda _: jnp.array(0, dtype=jnp.int32), operand=None) + spawn_mask = (jnp.arange(self.consts.MAX_ENEMY_CARS) == spawn_idx) & should_spawn & has_inactive + + spawn_offset = self.consts.ENEMY_OFFSCREEN_SPAWN_OFFSET + active_count * self.consts.ENEMY_MIN_SPAWN_GAP + jax.random.uniform(key_spawn_offset, minval=0.0, maxval=40.0) + spawn_side = jax.random.choice(key_spawn_side, jnp.array([-1.0, 1.0])) + raw_spawn_y = state.player_car.position.y + spawn_side * spawn_offset + spawn_y = -(((raw_spawn_y) * -1) % 1036) + spawn_road = jax.random.randint(key_spawn_direction, shape=(), minval=0, maxval=2) + + def get_road_segment(y): + segments = jnp.sum(self.consts.TRACK_CORNERS_Y > y) + return jnp.clip(segments - 1, 0, len(self.consts.TRACK_CORNERS_Y) - 2) + + segment_spawn = get_road_segment(spawn_y) + spawn_x = jax.lax.cond( + spawn_road == 0, + lambda _: self._get_x_on_road(spawn_y, segment_spawn, self.consts.FIRST_TRACK_CORNERS_X), + lambda _: self._get_x_on_road(spawn_y, segment_spawn, self.consts.SECOND_TRACK_CORNERS_X), + operand=None, + ) + + spawn_speed_mag = jax.random.randint(key_spawn_speed, shape=(), minval=self.consts.ENEMY_SPEED_MIN, maxval=self.consts.ENEMY_SPEED_MAX + 1) + spawn_speed_sign = jax.random.choice(key_spawn_sign, jnp.array([-1, 1])) + spawn_speed = spawn_speed_mag * spawn_speed_sign + spawn_type = jax.random.randint(key_spawn_type, shape=(), minval=0, maxval=4) + + direction_raw = jax.lax.cond( + spawn_road == 0, + lambda _: self.consts.FIRST_TRACK_CORNERS_X[segment_spawn+1] - self.consts.FIRST_TRACK_CORNERS_X[segment_spawn], + lambda _: self.consts.SECOND_TRACK_CORNERS_X[segment_spawn+1] - self.consts.SECOND_TRACK_CORNERS_X[segment_spawn], + operand=None, + ) + spawn_direction_x = jax.lax.cond(direction_raw > 0, lambda _: 1, lambda _: -1, operand=None) + + enemy_position_x = jnp.where(spawn_mask, spawn_x, state.enemy_cars.position.x) + enemy_position_y = jnp.where(spawn_mask, spawn_y, state.enemy_cars.position.y) + enemy_width = state.enemy_cars.position.width + enemy_height = state.enemy_cars.position.height + enemy_speed = jnp.where(spawn_mask, spawn_speed, state.enemy_cars.speed) + enemy_type = jnp.where(spawn_mask, spawn_type, state.enemy_cars.type) + enemy_current_road = jnp.where(spawn_mask, spawn_road, state.enemy_cars.current_road) + enemy_road_index_A = jnp.where(spawn_mask, segment_spawn, state.enemy_cars.road_index_A) + enemy_road_index_B = jnp.where(spawn_mask, segment_spawn, state.enemy_cars.road_index_B) + enemy_direction_x = jnp.where(spawn_mask, spawn_direction_x, state.enemy_cars.direction_x) + enemy_active = jnp.where(spawn_mask, True, state.enemy_cars.active) + enemy_age = jnp.where(spawn_mask, jnp.zeros_like(state.enemy_cars.age), state.enemy_cars.age) + + flip_keys = jax.random.split(key_flip_root, self.consts.MAX_ENEMY_CARS) + flip_mask = jax.vmap(lambda k: jax.random.uniform(k) < self.consts.ENEMY_DIRECTION_SWITCH_PROB)(flip_keys) + enemy_speed = jnp.where(jnp.logical_and(enemy_active, flip_mask), -enemy_speed, enemy_speed) + + move_fn = lambda px, py, ra, rb, cr, sp, tp: self._advance_car_core( + position_x=px, + position_y=py, + road_index_A=ra, + road_index_B=rb, + current_road=cr, + speed=sp, + is_jumping=False, + is_on_road=True, + step_counter=state.step_counter, + width=self.consts.PLAYER_SIZE[0], + height=self.consts.PLAYER_SIZE[1], + car_type=tp, + landing_check=False, + ) + + advanced_cars = jax.vmap(move_fn)( + enemy_position_x, + enemy_position_y, + enemy_road_index_A, + enemy_road_index_B, + enemy_current_road, + enemy_speed, + enemy_type, + ) + + moved_position_x = jnp.where(enemy_active, advanced_cars.position.x, enemy_position_x) + moved_position_y = jnp.where(enemy_active, advanced_cars.position.y, enemy_position_y) + moved_road_index_A = jnp.where(enemy_active, advanced_cars.road_index_A, enemy_road_index_A) + moved_road_index_B = jnp.where(enemy_active, advanced_cars.road_index_B, enemy_road_index_B) + moved_current_road = jnp.where(enemy_active, advanced_cars.current_road, enemy_current_road) + moved_direction_x = jnp.where(enemy_active, advanced_cars.direction_x, enemy_direction_x) + + enemy_age = jnp.where(enemy_active, enemy_age + 1, enemy_age) + + delta_y = moved_position_y - state.player_car.position.y + wrapped_dist = jnp.minimum(jnp.abs(delta_y), 1036 - jnp.abs(delta_y)) + far_mask = wrapped_dist > self.consts.ENEMY_DESPAWN_DISTANCE + age_mask = enemy_age > self.consts.ENEMY_MAX_AGE + despawn_mask = jnp.logical_and(enemy_active, jnp.logical_or(far_mask, age_mask)) + final_active = jnp.logical_and(enemy_active, jnp.logical_not(despawn_mask)) + enemy_age = jnp.where(despawn_mask, jnp.zeros_like(enemy_age), enemy_age) + + next_enemy_cars = EnemyCars( + position=EntityPosition( + x=moved_position_x, + y=moved_position_y, + width=enemy_width, + height=enemy_height, + ), + speed=enemy_speed, + type=enemy_type, + current_road=moved_current_road, + road_index_A=moved_road_index_A, + road_index_B=moved_road_index_B, + direction_x=moved_direction_x, + active=final_active, + age=enemy_age, + ) + + return UpNDownState( + score=state.score, + difficulty=state.difficulty, + jump_cooldown=state.jump_cooldown, + is_jumping=state.is_jumping, + is_on_road=state.is_on_road, + player_car=state.player_car, + step_counter=state.step_counter, + round_started=state.round_started, + movement_steps=state.movement_steps, + flags=state.flags, + flags_collected_mask=state.flags_collected_mask, + collectibles=state.collectibles, + collectible_spawn_timer=state.collectible_spawn_timer, + enemy_cars=next_enemy_cars, + enemy_spawn_timer=spawn_timer, ) def _passive_score_step_main(self, state: UpNDownState) -> UpNDownState: @@ -796,6 +1007,8 @@ def _passive_score_step_main(self, state: UpNDownState) -> UpNDownState: flags_collected_mask=state.flags_collected_mask, collectibles=state.collectibles, collectible_spawn_timer=state.collectible_spawn_timer, + enemy_cars=state.enemy_cars, + enemy_spawn_timer=state.enemy_spawn_timer, ) @@ -845,45 +1058,90 @@ def get_road_segment(y): type_id=jnp.zeros(self.consts.MAX_COLLECTIBLES, dtype=jnp.int32), active=jnp.zeros(self.consts.MAX_COLLECTIBLES, dtype=jnp.bool_), ) - player_car = Car( - position=EntityPosition( - x=jnp.array(30, dtype=jnp.int32), - y=jnp.array(0, dtype=jnp.int32), - width=jnp.array(self.consts.PLAYER_SIZE[0], dtype=jnp.int32), - height=jnp.array(self.consts.PLAYER_SIZE[1], dtype=jnp.int32), - ), - speed=jnp.array(0, dtype=jnp.int32), - direction_x=jnp.array(0, dtype=jnp.int32), - current_road=jnp.array(0, dtype=jnp.int32), - road_index_A=jnp.array(0, dtype=jnp.int32), - road_index_B=jnp.array(0, dtype=jnp.int32), - type=jnp.array(0, dtype=jnp.int32), - ) - state = UpNDownState( - score=jnp.array(0, dtype=jnp.int32), - difficulty=jnp.array(self.consts.DIFFICULTIES[0], dtype=jnp.int32), - jump_cooldown=jnp.array(0, dtype=jnp.int32), - is_jumping=jnp.array(False), - is_on_road=jnp.array(True), - - player_car=player_car, - step_counter=jnp.array(0, dtype=jnp.int32), - round_started=jnp.array(False), - movement_steps=jnp.array(0, dtype=jnp.int32), + def get_road_segment(y): + segments = jnp.sum(self.consts.TRACK_CORNERS_Y > y) + return jnp.clip(segments - 1, 0, len(self.consts.TRACK_CORNERS_Y) - 2) - flags=flags, - flags_collected_mask=jnp.zeros(self.consts.NUM_FLAGS, dtype=jnp.bool_), - collectibles=collectibles, - collectible_spawn_timer=jnp.array(0, dtype=jnp.int32), + # Seed initial visible enemies spaced around the player + key_init, key_type, key_road, key_speed, key_sign = jax.random.split(key, 5) + player_start_y = 0.0 + offsets = self.consts.INITIAL_ENEMY_BASE_OFFSET + self.consts.INITIAL_ENEMY_GAP * jnp.arange(self.consts.INITIAL_ENEMY_COUNT) + spawn_signs = jax.random.choice(key_sign, jnp.array([-1.0, 1.0]), shape=(self.consts.INITIAL_ENEMY_COUNT,)) + raw_spawn_y = player_start_y + spawn_signs * offsets + init_y = -(((raw_spawn_y) * -1) % 1036) + init_road = jax.random.randint(key_road, shape=(self.consts.INITIAL_ENEMY_COUNT,), minval=0, maxval=2) + init_segments = jax.vmap(get_road_segment)(init_y) + init_x = jax.vmap(lambda y, seg, road: jax.lax.cond( + road == 0, + lambda _: self._get_x_on_road(y, seg, self.consts.FIRST_TRACK_CORNERS_X), + lambda _: self._get_x_on_road(y, seg, self.consts.SECOND_TRACK_CORNERS_X), + operand=None, + ))(init_y, init_segments, init_road) + init_type = jax.random.randint(key_type, shape=(self.consts.INITIAL_ENEMY_COUNT,), minval=0, maxval=4) + init_speed_mag = jax.random.randint(key_speed, shape=(self.consts.INITIAL_ENEMY_COUNT,), minval=self.consts.ENEMY_SPEED_MIN, maxval=self.consts.ENEMY_SPEED_MAX + 1) + init_speed_sign = jax.random.choice(key_init, jnp.array([-1, 1]), shape=(self.consts.INITIAL_ENEMY_COUNT,)) + init_speed = init_speed_mag * init_speed_sign + + def init_direction(seg, road): + raw = jax.lax.cond( + road == 0, + lambda _: self.consts.FIRST_TRACK_CORNERS_X[seg+1] - self.consts.FIRST_TRACK_CORNERS_X[seg], + lambda _: self.consts.SECOND_TRACK_CORNERS_X[seg+1] - self.consts.SECOND_TRACK_CORNERS_X[seg], + operand=None, + ) + return jax.lax.cond(raw > 0, lambda _: 1, lambda _: -1, operand=None) - # -------- NEW REQUIRED FIELDS -------- - lives=jnp.array(self.consts.INITIAL_LIVES, dtype=jnp.int32), - is_dead=jnp.array(False), - respawn_timer=jnp.array(0, dtype=jnp.int32), - ) + init_dir = jax.vmap(init_direction)(init_segments, init_road) + pad = self.consts.MAX_ENEMY_CARS - self.consts.INITIAL_ENEMY_COUNT + enemy_cars = EnemyCars( + position=EntityPosition( + x=jnp.concatenate([init_x, jnp.zeros(pad, dtype=jnp.float32)]), + y=jnp.concatenate([init_y, jnp.zeros(pad, dtype=jnp.float32)]), + width=jnp.full((self.consts.MAX_ENEMY_CARS,), self.consts.PLAYER_SIZE[0]), + height=jnp.full((self.consts.MAX_ENEMY_CARS,), self.consts.PLAYER_SIZE[1]), + ), + speed=jnp.concatenate([init_speed, jnp.zeros(pad, dtype=jnp.int32)]), + type=jnp.concatenate([init_type, jnp.zeros(pad, dtype=jnp.int32)]), + current_road=jnp.concatenate([init_road, jnp.zeros(pad, dtype=jnp.int32)]), + road_index_A=jnp.concatenate([init_segments, jnp.zeros(pad, dtype=jnp.int32)]), + road_index_B=jnp.concatenate([init_segments, jnp.zeros(pad, dtype=jnp.int32)]), + direction_x=jnp.concatenate([init_dir, jnp.zeros(pad, dtype=jnp.int32)]), + active=jnp.concatenate([jnp.ones(self.consts.INITIAL_ENEMY_COUNT, dtype=jnp.bool_), jnp.zeros(pad, dtype=jnp.bool_)]), + age=jnp.concatenate([jnp.zeros(self.consts.INITIAL_ENEMY_COUNT, dtype=jnp.int32), jnp.zeros(pad, dtype=jnp.int32)]), + ) + state = UpNDownState( + score=0, + difficulty=self.consts.DIFFICULTIES[0], + jump_cooldown=0, + is_jumping=False, + is_on_road=True, + player_car=Car( + position=EntityPosition( + x=30, + y= 0, + width=self.consts.PLAYER_SIZE[0], + height=self.consts.PLAYER_SIZE[1], + ), + speed=0, + direction_x=0, + current_road=0, + road_index_A=0, + road_index_B=0, + type=0, + ), + step_counter=jnp.array(0), + round_started=jnp.array(False), + movement_steps=jnp.array(0), + flags=flags, + flags_collected_mask=jnp.zeros(self.consts.NUM_FLAGS, dtype=jnp.bool_), + collectibles=collectibles, + collectible_spawn_timer=jnp.array(0, dtype=jnp.int32), + enemy_cars=enemy_cars, + enemy_spawn_timer=jnp.array(self.consts.ENEMY_SPAWN_INTERVAL, dtype=jnp.int32), + ) initial_obs = self._get_observation(state) return initial_obs, state @@ -897,8 +1155,7 @@ def step(self, state: UpNDownState, action: chex.Array) -> Tuple[UpNDownObservat state = self._flag_step_main(state) state = self._completion_bonus_step(state) state = self._collectible_step_main(state) - - + state = self._enemy_step_main(state) done = self._get_done(state) env_reward = self._get_reward(previous_state, state) @@ -1006,6 +1263,37 @@ def __init__(self, consts: UpNDownConstants = None): repeats = max(1, math.ceil(self.view_height / road_cycle) + 2) self._road_tile_offsets = jnp.arange(-repeats, repeats + 1, dtype=jnp.int32) * jnp.int32(self.complete_road_size) self._num_road_tiles = int(self._road_tile_offsets.shape[0]) + + self.enemy_sprite_names = { + self.consts.ENEMY_TYPE_CAMERO: ("camero_left", "camero_right"), + self.consts.ENEMY_TYPE_FLAG_CARRIER: ("flag_carrier_left", "flag_carrier_right"), + self.consts.ENEMY_TYPE_PICKUP: ("pick_up_truck_left", "pick_up_truck_right"), + self.consts.ENEMY_TYPE_TRUCK: ("truck_left", "truck_right"), + } + + # Pre-pad enemy masks to a common shape so switch/array indexing works under jit + enemy_left_raw = [ + self.SHAPE_MASKS["camero_left"], + self.SHAPE_MASKS["flag_carrier_left"], + self.SHAPE_MASKS["pick_up_truck_left"], + self.SHAPE_MASKS["truck_left"], + ] + enemy_right_raw = [ + self.SHAPE_MASKS["camero_right"], + self.SHAPE_MASKS["flag_carrier_right"], + self.SHAPE_MASKS["pick_up_truck_right"], + self.SHAPE_MASKS["truck_right"], + ] + max_h = max([m.shape[0] for m in enemy_left_raw + enemy_right_raw]) + max_w = max([m.shape[1] for m in enemy_left_raw + enemy_right_raw]) + + def _pad_mask(mask): + pad_h = max_h - mask.shape[0] + pad_w = max_w - mask.shape[1] + return jnp.pad(mask, ((0, pad_h), (0, pad_w)), constant_values=self.jr.TRANSPARENT_ID) + + self.enemy_left_masks = jnp.stack([_pad_mask(m) for m in enemy_left_raw], axis=0) + self.enemy_right_masks = jnp.stack([_pad_mask(m) for m in enemy_right_raw], axis=0) # Precompute flag mask data for recoloring without special-casing pink self.flag_base_mask = self.SHAPE_MASKS["pink_flag"] @@ -1081,6 +1369,14 @@ def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.n {'name': 'background', 'type': 'background', 'data': backgroundSprite}, {'name': 'road', 'type': 'group', 'files': roads}, {'name': 'player', 'type': 'single', 'file': 'player_car.npy'}, + {'name': 'camero_left', 'type': 'single', 'file': 'enemy_cars/camero_left.npy'}, + {'name': 'camero_right', 'type': 'single', 'file': 'enemy_cars/camero_right.npy'}, + {'name': 'flag_carrier_left', 'type': 'single', 'file': 'enemy_cars/flag_carrier_left.npy'}, + {'name': 'flag_carrier_right', 'type': 'single', 'file': 'enemy_cars/flag_carrier_right.npy'}, + {'name': 'pick_up_truck_left', 'type': 'single', 'file': 'enemy_cars/pick_up_truck_left.npy'}, + {'name': 'pick_up_truck_right', 'type': 'single', 'file': 'enemy_cars/pick_up_truck_right.npy'}, + {'name': 'truck_left', 'type': 'single', 'file': 'enemy_cars/truck_left.npy'}, + {'name': 'truck_right', 'type': 'single', 'file': 'enemy_cars/truck_right.npy'}, {'name': 'wall_top', 'type': 'procedural', 'data': topBlockSprite}, {'name': 'wall_bottom', 'type': 'procedural', 'data': bottomBlockSprite}, {'name': 'all_flags_top', 'type': 'single', 'file': 'all_flags_top.npy'}, @@ -1146,6 +1442,32 @@ def combine(i, acc): raster = jax.lax.fori_loop(0, total_segments, combine, raster) + def select_enemy_mask(enemy_type: chex.Array, going_left: chex.Array): + left_mask = self.enemy_left_masks[enemy_type] + right_mask = self.enemy_right_masks[enemy_type] + return jax.lax.cond(going_left, lambda _: left_mask, lambda _: right_mask, operand=None) + + def render_enemy(carry, enemy_idx): + raster = carry + enemy_active = state.enemy_cars.active[enemy_idx] + enemy_x = state.enemy_cars.position.x[enemy_idx] + enemy_y = state.enemy_cars.position.y[enemy_idx] + enemy_type = state.enemy_cars.type[enemy_idx] + direction_x = state.enemy_cars.direction_x[enemy_idx] + screen_y = 105 + (enemy_y - state.player_car.position.y) + is_visible = jnp.logical_and(enemy_active, jnp.logical_and(screen_y > 25, screen_y < 195)) + enemy_mask = select_enemy_mask(enemy_type, direction_x < 0) + + raster = jax.lax.cond( + is_visible, + lambda r: self.jr.render_at(r, enemy_x.astype(jnp.int32), screen_y.astype(jnp.int32), enemy_mask), + lambda r: r, + operand=raster, + ) + return raster, None + + raster, _ = jax.lax.scan(render_enemy, raster, jnp.arange(self.consts.MAX_ENEMY_CARS)) + player_mask = self.SHAPE_MASKS["player"] raster = self.jr.render_at(raster, state.player_car.position.x, 105, player_mask) diff --git a/src/jaxatari/games/sprites/up_n_down/enemy_cars/camero_left.npy b/src/jaxatari/games/sprites/up_n_down/enemy_cars/camero_left.npy new file mode 100644 index 0000000000000000000000000000000000000000..0ca074f09f0ebc702ce053a6b2f304d0c2650980 GIT binary patch literal 1024 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1JQ);NLqoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$-W;zOnCOQfxnmP)#3giGT22AiDD1;4Ts)UGOQ;RG{NFPidriT!pSas;; zkfg6Z=IVcB3^a@={m6PiY?wZnI4(X+9GxarA0hSV@-X|*#R>7z^}*ERqJiqs0X%Gg Y68PA((g!ma-ENpTjK(F8OPnNq0O`Z~e*gdg literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/enemy_cars/camero_right.npy b/src/jaxatari/games/sprites/up_n_down/enemy_cars/camero_right.npy new file mode 100644 index 0000000000000000000000000000000000000000..73e0114ad05cadadd23340dad61520e686889e57 GIT binary patch literal 1024 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1JQ);NLqoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$-W;zOnCOQfxnmP)#3giGT22AiDD1;4TszepRrUzLJRW-s6m^e%=E5Z{6EX*;4@TovUmtVzKQaa?2J3@~5#kf84rUHIO|E`Q^}+PxqS4(+0w1Up1;FBg U5RIY<$U#>RlgA}aNS-G80CsTue*gdg literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/enemy_cars/flag_carrier_left.npy b/src/jaxatari/games/sprites/up_n_down/enemy_cars/flag_carrier_left.npy new file mode 100644 index 0000000000000000000000000000000000000000..ff75c712527fce10581c908a23a8c6232dfd29f9 GIT binary patch literal 1152 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1JQ);NLqoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$-W;zNWY@(^7P^&-=;9|f8|A9i-Fs4c*k&v!S|M6iYGm&&5nE3R9p%VSE@36URm4QU?Yj&)KN;q%m0j#iKSusU^KD%VCrD{U^Fgq zm^_R|R|n(6XqY-&G`cz%pHv!VKB@X(>QVGxZSwjL!azZ?Ftzo=%)>>ayA!4kCJy7{ kqG9TA(YW-%#Bu2x0CAv|C;&bFV3N4RVe&8j#iKSusU^KD%VCrD{U^Fgq zm^_Ter4E-kOdpJfse{qD#9{J;XiD`HG9N|$YLnN05C&RG7RIF?CXS1S*$0z{@nJMf j92bqQ4#tPkxb(rqVd@4U4RjJ1pvMOxahN=eCQCm6WAqVj literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/enemy_cars/truck_left.npy b/src/jaxatari/games/sprites/up_n_down/enemy_cars/truck_left.npy new file mode 100644 index 0000000000000000000000000000000000000000..b7bab5aa92e91bf90abec3d22fba7a01dfa38fd9 GIT binary patch literal 1024 zcmbR27wQ`j$;eQ~P_3SlTAW;@Zl$1JQ);NLqoAIaUsO_*m=~X4l#&V(cT3DEP6dh= zXCxM+0{I$-COQg+W;zNcnmP)#3giGT1_byIWDM!`y)`j>3nTNk~6T9XgFgAIyBb_QTA4VWjS{?wO*YJM; literal 0 HcmV?d00001 From f002190f4ed0becf960cefe74291c0df33cdca08 Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sat, 20 Dec 2025 21:16:54 +0100 Subject: [PATCH 19/76] add jumping, crossing roads to game --- src/jaxatari/games/jax_upndown.py | 1343 +++++++++++++++++++---------- 1 file changed, 880 insertions(+), 463 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 648859a53..2d8de2915 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -3,6 +3,7 @@ import math from functools import partial from typing import NamedTuple, Tuple +import jax import jax.lax import jax.numpy as jnp import chex @@ -16,31 +17,37 @@ class UpNDownConstants(NamedTuple): FRAME_SKIP: int = 4 DIFFICULTIES: chex.Array = jnp.array([0, 1, 2, 3, 4, 5]) ACTION_REPEAT_PROBS: float = 0.25 - MAX_SPEED: int = 5 + MAX_SPEED: int = 6 + INITIAL_LIVES: int = 5 + RESPAWN_HIDE_FRAMES: int = 8 + JUMP_ARC_HEIGHT: float = 18.0 # Enemy spawning and movement - MAX_ENEMY_CARS: int = 6 + MAX_ENEMY_CARS: int = 8 ENEMY_SPAWN_INTERVAL: int = 80 ENEMY_DESPAWN_DISTANCE: int = 300 - ENEMY_SPEED_MIN: int = 2 + ENEMY_SPEED_MIN: int = 3 ENEMY_SPEED_MAX: int = 5 - ENEMY_DIRECTION_SWITCH_PROB: float = 0.005 + ENEMY_DIRECTION_SWITCH_PROB: float = 0.0001 ENEMY_OFFSCREEN_SPAWN_OFFSET: float = 100.0 - ENEMY_MIN_SPAWN_GAP: float = 40.0 - ENEMY_MAX_AGE: int = 900 - INITIAL_ENEMY_COUNT: int = 3 + ENEMY_MIN_SPAWN_GAP: float = 30.0 + ENEMY_MAX_AGE: int = 1900 + INITIAL_ENEMY_COUNT: int = 4 INITIAL_ENEMY_BASE_OFFSET: float = 40.0 - INITIAL_ENEMY_GAP: float = 50.0 + INITIAL_ENEMY_GAP: float = 30.0 ENEMY_TYPE_CAMERO: int = 0 ENEMY_TYPE_FLAG_CARRIER: int = 1 ENEMY_TYPE_PICKUP: int = 2 ENEMY_TYPE_TRUCK: int = 3 - JUMP_FRAMES: int = 10 - ALL_FLAGS_BONUS: int = 1000 - LANDING_ZONE: int = 15 - FIRST_ROAD_LENGTH: int = 4 - SECOND_ROAD_LENGTH: int = 4 + JUMP_FRAMES: int = 28 + POST_JUMP_DELAY: int = 10 + LANDING_TOLERANCE: int = 15 # Pixels tolerance for landing on a road (increased by 5 for off-road landings) + LATE_JUMP_COLLISION_FRAMES: int = 2 + LANDING_COLLISION_DISTANCE: float = 8.0 # Larger collision distance when landing (for crossings) + LATE_JUMP_ENEMY_SCORE: int = 400 + STEEP_ROAD_SPEED_REDUCTION_INTERVAL: int = 8 # Frames between each speed reduction on steep roads + TRACK_LENGTH: int = 1036 FIRST_TRACK_CORNERS_X: chex.Array = jnp.array([30, 75, 128, 75, 21, 75, 131, 111, 150, 95, 150, 115, 150, 108, 150, 115, 115, 75, 18, 38, 67, 38, 38, 20, 64, 30]) - TRACK_CORNERS_Y: chex.Array = jnp.array([0, -40, -98, -155, -203, -268, -327, -347, -382, -467, -525, -565, -597, -625, -670, -705, -738, -788, -838, -862, -898, -925, -950, -972, -1000, -1036]) + TRACK_CORNERS_Y: chex.Array = jnp.array([0, -40, -98, -155, -203, -268, -327, -347, -382, -467, -525, -565, -597, -625, -670, -705, -738, -788, -838, -862, -898, -925, -950, -972, -1000, -1035]) SECOND_TRACK_CORNERS_X: chex.Array = jnp.array([115, 75, 20, 75, 133, 75, 22, 37, 63, 27, 66, 30, 63, 24, 60, 38, 38, 75, 131, 111, 150, 118, 118, 98, 150, 115]) PLAYER_SIZE: Tuple[int, int] = (4, 16) INITIAL_ROAD_POS_Y: int = 25 @@ -60,7 +67,7 @@ class UpNDownConstants(NamedTuple): [78, 50, 181, 255], # Blue ]) # Top display positions for each flag (x coordinates where blackout squares appear) - FLAG_TOP_X_POSITIONS: chex.Array = jnp.array([13, 30, 47, 64, 82, 98, 118, 132]) + FLAG_TOP_X_POSITIONS: chex.Array = jnp.array([13, 30, 47, 64, 82, 98, 118, 134]) FLAG_TOP_Y: int = 20 FLAG_BLACKOUT_SIZE: Tuple[int, int] = (14, 14) # Size of blackout square FLAG_COLLECTION_SCORE: int = 75 # Points awarded for collecting a flag @@ -79,7 +86,7 @@ class UpNDownConstants(NamedTuple): COLLECTIBLE_TYPE_LOLLYPOP: int = 2 COLLECTIBLE_TYPE_ICE_CREAM: int = 3 # Collectible type spawn probabilities (must sum to 100) - COLLECTIBLE_SPAWN_PROBABILITIES: chex.Array = jnp.array([40, 20, 20, 20], dtype=jnp.int32) # Cherry: 40%, Balloon: 20%, Lollypop: 20%, IceCream: 20% + COLLECTIBLE_SPAWN_PROBABILITIES: chex.Array = jnp.array([35, 30, 25, 10], dtype=jnp.int32) # Cherry: 40%, Balloon: 20%, Lollypop: 20%, IceCream: 20% # Collectible type scores COLLECTIBLE_SCORES: chex.Array = jnp.array([50, 65, 70, 75], dtype=jnp.int32) # [cherry, balloon, lollypop, ice_cream] # Shared collectible colors @@ -139,8 +146,11 @@ class EnemyCars(NamedTuple): class UpNDownState(NamedTuple): score: chex.Array + lives: chex.Array + respawn_cooldown: chex.Array difficulty: chex.Array jump_cooldown: chex.Array + post_jump_cooldown: chex.Array is_jumping: chex.Array is_on_road: chex.Array player_car: Car @@ -150,6 +160,8 @@ class UpNDownState(NamedTuple): step_counter: chex.Array round_started: chex.Array movement_steps: chex.Array + steep_road_timer: chex.Array # Timer for steep road speed reduction + jump_slope: chex.Array # X movement per Y step, locked at jump start (float) # Flag state - tracks all 8 flags flags: Flag # Contains arrays of size NUM_FLAGS for each field flags_collected_mask: chex.Array # Boolean mask of which flag colors have been collected (size NUM_FLAGS) @@ -239,6 +251,32 @@ def _get_x_on_road(self, y: chex.Array, road_segment: chex.Array, track_corners_ operand=None, ) return x1 + t * (x2 - x1) + + @partial(jax.jit, static_argnums=(0,)) + def _get_road_segment(self, y: chex.Array) -> chex.Array: + """Return the road segment index for a given y position.""" + segments = jnp.sum(self.consts.TRACK_CORNERS_Y > y, dtype=jnp.int32) + max_idx = jnp.int32(len(self.consts.TRACK_CORNERS_Y) - 1) + return jnp.clip(segments - 1, 0, max_idx) + + @partial(jax.jit, static_argnums=(0,)) + def _is_steep_road_segment(self, current_road: chex.Array, road_index_A: chex.Array, road_index_B: chex.Array) -> chex.Array: + """Check if the current road segment is steep (no X direction change). + + A steep segment is one where the X coordinates of consecutive corners are the same, + meaning the road goes straight up/down with no horizontal movement. + + Returns True if the segment is steep (requires jump to pass when going up). + """ + # Get the X difference for the current road segment + x_diff = jax.lax.cond( + current_road == 0, + lambda _: jnp.abs(self.consts.FIRST_TRACK_CORNERS_X[road_index_A + 1] - self.consts.FIRST_TRACK_CORNERS_X[road_index_A]), + lambda _: jnp.abs(self.consts.SECOND_TRACK_CORNERS_X[road_index_B + 1] - self.consts.SECOND_TRACK_CORNERS_X[road_index_B]), + operand=None, + ) + # A segment is steep if there's no X change (or very small change) + return x_diff < 1.0 @partial(jax.jit, static_argnums=(0,)) def _isOnLine(self, state: UpNDownState, player_speed: chex.Array, turn: chex.Array) -> chex.Array: @@ -251,7 +289,7 @@ def _landing_in_water(self, state: UpNDownState, new_position_x: chex.Array, new road_B_x = ((new_position_y - self.consts.TRACK_CORNERS_Y[state.player_car.road_index_B]) / (self.consts.TRACK_CORNERS_Y[state.player_car.road_index_B+1] - self.consts.TRACK_CORNERS_Y[state.player_car.road_index_B])) * (self.consts.SECOND_TRACK_CORNERS_X[state.player_car.road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[state.player_car.road_index_B]) + self.consts.SECOND_TRACK_CORNERS_X[state.player_car.road_index_B] distance_to_road_A = jnp.abs(new_position_x - road_A_x) distance_to_road_B = jnp.abs(new_position_x - road_B_x) - landing_in_Water = jnp.logical_and(distance_to_road_A > self.consts.LANDING_ZONE, distance_to_road_B > self.consts.LANDING_ZONE) + landing_in_Water = jnp.logical_and(distance_to_road_A > self.consts.LANDING_TOLERANCE, distance_to_road_B > self.consts.LANDING_TOLERANCE) between_roads = jnp.logical_and(new_position_x > jnp.minimum(road_A_x, road_B_x), new_position_x < jnp.maximum(road_A_x, road_B_x)) return landing_in_Water, between_roads, road_A_x, road_B_x @@ -261,12 +299,12 @@ def _landing_in_water_for_indices(self, road_index_A: chex.Array, road_index_B: road_B_x = ((new_position_y - self.consts.TRACK_CORNERS_Y[road_index_B]) / (self.consts.TRACK_CORNERS_Y[road_index_B+1] - self.consts.TRACK_CORNERS_Y[road_index_B])) * (self.consts.SECOND_TRACK_CORNERS_X[road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[road_index_B]) + self.consts.SECOND_TRACK_CORNERS_X[road_index_B] distance_to_road_A = jnp.abs(new_position_x - road_A_x) distance_to_road_B = jnp.abs(new_position_x - road_B_x) - landing_in_water = jnp.logical_and(distance_to_road_A > self.consts.LANDING_ZONE, distance_to_road_B > self.consts.LANDING_ZONE) + landing_in_water = jnp.logical_and(distance_to_road_A > self.consts.LANDING_TOLERANCE, distance_to_road_B > self.consts.LANDING_TOLERANCE) between_roads = jnp.logical_and(new_position_x > jnp.minimum(road_A_x, road_B_x), new_position_x < jnp.maximum(road_A_x, road_B_x)) return landing_in_water, between_roads, road_A_x, road_B_x @partial(jax.jit, static_argnums=(0,)) - def _advance_car_core( + def _advance_player_car( self, position_x: chex.Array, position_y: chex.Array, @@ -280,18 +318,33 @@ def _advance_car_core( width: chex.Array, height: chex.Array, car_type: chex.Array, - landing_check: chex.Array, + is_landing: chex.Array, + stored_jump_slope: chex.Array, ) -> Car: - dividers = jnp.array([0, 1, 2, 4, 8, 16]) + """ + Advance the player car position. + + Jump logic: + - Car jumps in the direction of the road it's on at current speed + - While jumping, car moves freely (not constrained to road) + - On landing: check if car is on/near a road or between roads + - If between roads: snap to nearest road + - If too far from both roads (outside the road area): crash (water) + """ + # Speed-based movement timing + dividers = jnp.array([0, 1, 2, 4, 8, 16, 16]) abs_speed = jnp.abs(speed) - speed_divider = dividers[abs_speed] + speed_index = jnp.minimum(abs_speed, jnp.int32(dividers.shape[0] - 1)) + speed_divider = dividers[speed_index] effective_divider = jnp.maximum(1, speed_divider) period = jnp.maximum(1, 16 // effective_divider) half_period = jnp.maximum(1, period // 2) speed_sign = jax.lax.cond(speed != 0, lambda _: jax.lax.abs(speed) / speed, lambda _: jnp.array(0.0), operand=None) + # Get slope and intercept for current road slope, b = self._get_slope_and_intercept_from_indices(current_road, road_index_A, road_index_B) + # Determine X direction based on current road segment (for normal movement) direction_raw = jax.lax.cond( current_road == 0, lambda _: self.consts.FIRST_TRACK_CORNERS_X[road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_X[road_index_A], @@ -300,19 +353,26 @@ def _advance_car_core( ) car_direction_x = jax.lax.cond(direction_raw > 0, lambda _: 1, lambda _: -1, operand=None) + # Movement timing flags move_y = jnp.logical_and((step_counter % period) == (half_period % period), speed != 0) move_x = jnp.logical_and((step_counter % period) == 0, speed != 0) + # Step size (slightly larger at max speed) + step_size = jnp.where(speed_index >= self.consts.MAX_SPEED, 1.5, 1.0) + position = EntityPosition(x=position_x, y=position_y, width=width, height=height) + # === Y MOVEMENT === + # When jumping: move freely in Y direction + # When on road: only move if allowed by road geometry new_player_y = jax.lax.cond( move_y, lambda _: jax.lax.cond( is_jumping, - lambda _: position_y + speed_sign * -1, + lambda _: position_y + speed_sign * -step_size, # Free movement while jumping lambda _: jax.lax.cond( self._is_on_line_for_position(position, slope, b, speed_sign, 1), - lambda _: position_y + speed_sign * -1, + lambda _: position_y + speed_sign * -step_size, lambda _: jnp.array(position_y, float), operand=None, ), @@ -322,14 +382,18 @@ def _advance_car_core( operand=None, ) + # === X MOVEMENT === + # When jumping: use stored_jump_slope (locked at jump start) - moves X proportionally to Y + # The slope already encodes direction (dx/dy), so multiply by Y step size and speed_sign + # When on road: only move if allowed by road geometry new_player_x = jax.lax.cond( move_x, lambda _: jax.lax.cond( is_jumping, - lambda _: position_x + speed_sign * car_direction_x, + lambda _: position_x - speed_sign * stored_jump_slope * step_size, # Slope-based movement (negated because Y decreases going forward) lambda _: jax.lax.cond( self._is_on_line_for_position(position, slope, b, speed_sign, 2), - lambda _: position_x + speed_sign * car_direction_x, + lambda _: position_x + speed_sign * car_direction_x * step_size, # Normal road movement lambda _: jnp.array(position_x, float), operand=None, ), @@ -339,44 +403,76 @@ def _advance_car_core( operand=None, ) - landing_in_water, between_roads, road_A_x, road_B_x = self._landing_in_water_for_indices(road_index_A, road_index_B, new_player_x, new_player_y) - landing_in_water = jnp.logical_and(landing_check, landing_in_water) - - updated_current_road = jax.lax.cond( - landing_in_water, - lambda _: 2, + # === LANDING LOGIC === + # Get the current road segment based on new Y position + segment = self._get_road_segment(new_player_y) + + # Calculate X positions of both roads at the new Y position + road_A_x = self._get_x_on_road(new_player_y, segment, self.consts.FIRST_TRACK_CORNERS_X) + road_B_x = self._get_x_on_road(new_player_y, segment, self.consts.SECOND_TRACK_CORNERS_X) + + # Calculate distances to each road + dist_to_road_A = jnp.abs(new_player_x - road_A_x) + dist_to_road_B = jnp.abs(new_player_x - road_B_x) + + # Check if player is close enough to either road (within tolerance) + on_road_A = dist_to_road_A <= self.consts.LANDING_TOLERANCE + on_road_B = dist_to_road_B <= self.consts.LANDING_TOLERANCE + on_any_road = jnp.logical_or(on_road_A, on_road_B) + + # Check if player is between the two roads + min_road_x = jnp.minimum(road_A_x, road_B_x) + max_road_x = jnp.maximum(road_A_x, road_B_x) + between_roads = jnp.logical_and(new_player_x > min_road_x, new_player_x < max_road_x) + + # Determine which road is closer + closer_to_A = dist_to_road_A < dist_to_road_B + nearest_road_x = jnp.where(closer_to_A, road_A_x, road_B_x) + nearest_road_id = jnp.where(closer_to_A, jnp.int32(0), jnp.int32(1)) + + # === LANDING OUTCOMES === + # Valid landing: on a road OR between roads (will snap to nearest) + valid_landing = jnp.logical_or(on_any_road, between_roads) + + # If landing and between roads but not directly on a road, snap to nearest road + should_snap = jnp.logical_and(is_landing, jnp.logical_and(between_roads, jnp.logical_not(on_any_road))) + final_player_x = jnp.where(should_snap, nearest_road_x, new_player_x) + + # Water landing (crash): landing outside the valid road area + landing_in_water = jnp.logical_and(is_landing, jnp.logical_not(valid_landing)) + + # === UPDATE ROAD STATE === + # Determine which road to assign on landing + landed_road = jax.lax.cond( + on_road_A, + lambda _: jnp.int32(0), lambda _: jax.lax.cond( - is_on_road, - lambda _: current_road, - lambda _: jax.lax.cond( - jnp.abs(new_player_x - road_A_x) < jnp.abs(new_player_x - road_B_x), - lambda _: 0, - lambda _: 1, - operand=None, - ), + on_road_B, + lambda _: jnp.int32(1), + lambda _: nearest_road_id, # Between roads - use nearest operand=None, ), operand=None, ) - - next_road_index_A = jax.lax.cond( - updated_current_road == 2, - lambda _: road_index_A, + + # Update current_road + # - If landing in water: set to 2 (water/crash marker) + # - If landing successfully: set to the landed road + # - If still jumping: keep current road (frozen during jump) + # - If on road normally: update based on position + updated_current_road = jax.lax.cond( + landing_in_water, + lambda _: jnp.int32(2), # Water crash lambda _: jax.lax.cond( - self.consts.TRACK_CORNERS_Y[road_index_A] < new_player_y, - lambda _: road_index_A - 1, + is_landing, + lambda _: landed_road, # Successfully landed lambda _: jax.lax.cond( - len(self.consts.TRACK_CORNERS_Y) == road_index_A + 1, + is_jumping, + lambda _: current_road, # Keep road frozen while jumping lambda _: jax.lax.cond( - self.consts.TRACK_CORNERS_Y[0] > new_player_y, - lambda _: 0, - lambda _: road_index_A, - operand=None, - ), - lambda _: jax.lax.cond( - self.consts.TRACK_CORNERS_Y[road_index_A+1] > new_player_y, - lambda _: road_index_A + 1, - lambda _: road_index_A, + current_road == 2, + lambda _: nearest_road_id, # Recover from water state + lambda _: current_road, # Normal on-road movement operand=None, ), operand=None, @@ -385,46 +481,135 @@ def _advance_car_core( ), operand=None, ) - + + # Update road indices to match current segment when not jumping + next_road_index_A = jax.lax.cond( + jnp.logical_and(jnp.logical_not(is_jumping), updated_current_road == 0), + lambda _: segment, + lambda _: road_index_A, + operand=None, + ) + next_road_index_B = jax.lax.cond( - updated_current_road == 2, + jnp.logical_and(jnp.logical_not(is_jumping), updated_current_road == 1), + lambda _: segment, lambda _: road_index_B, + operand=None, + ) + + # Wrap Y position for looping track + wrapped_y = -((new_player_y * -1) % 1036) + + return Car( + position=EntityPosition( + x=final_player_x, + y=wrapped_y, + width=width, + height=height, + ), + speed=speed, + direction_x=car_direction_x, + current_road=updated_current_road, + road_index_A=next_road_index_A, + road_index_B=next_road_index_B, + type=car_type, + ) + + @partial(jax.jit, static_argnums=(0,)) + def _advance_car_core( + self, + position_x: chex.Array, + position_y: chex.Array, + road_index_A: chex.Array, + road_index_B: chex.Array, + current_road: chex.Array, + speed: chex.Array, + step_counter: chex.Array, + width: chex.Array, + height: chex.Array, + car_type: chex.Array, + ) -> Car: + """Simplified car advancement for enemy cars (no jumping/landing logic).""" + dividers = jnp.array([0, 1, 2, 4, 8, 16, 16]) + abs_speed = jnp.abs(speed) + speed_index = jnp.minimum(abs_speed, jnp.int32(dividers.shape[0] - 1)) + speed_divider = dividers[speed_index] + effective_divider = jnp.maximum(1, speed_divider) + period = jnp.maximum(1, 16 // effective_divider) + half_period = jnp.maximum(1, period // 2) + speed_sign = jax.lax.cond(speed != 0, lambda _: jax.lax.abs(speed) / speed, lambda _: jnp.array(0.0), operand=None) + + slope, b = self._get_slope_and_intercept_from_indices(current_road, road_index_A, road_index_B) + + direction_raw = jax.lax.cond( + current_road == 0, + lambda _: self.consts.FIRST_TRACK_CORNERS_X[road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_X[road_index_A], + lambda _: self.consts.SECOND_TRACK_CORNERS_X[road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[road_index_B], + operand=None, + ) + car_direction_x = jax.lax.cond(direction_raw > 0, lambda _: 1, lambda _: -1, operand=None) + + move_y = jnp.logical_and((step_counter % period) == (half_period % period), speed != 0) + move_x = jnp.logical_and((step_counter % period) == 0, speed != 0) + + step_size = jnp.where(speed_index >= self.consts.MAX_SPEED, 1.5, 1.0) + + position = EntityPosition(x=position_x, y=position_y, width=width, height=height) + + new_y = jax.lax.cond( + move_y, lambda _: jax.lax.cond( - self.consts.TRACK_CORNERS_Y[road_index_B] < new_player_y, - lambda _: road_index_B - 1, - lambda _: jax.lax.cond( - len(self.consts.TRACK_CORNERS_Y) == road_index_B + 1, - lambda _: jax.lax.cond( - self.consts.TRACK_CORNERS_Y[0] > new_player_y, - lambda _: 0, - lambda _: road_index_B, - operand=None, - ), - lambda _: jax.lax.cond( - self.consts.TRACK_CORNERS_Y[road_index_B+1] > new_player_y, - lambda _: road_index_B + 1, - lambda _: road_index_B, - operand=None, - ), - operand=None, - ), + self._is_on_line_for_position(position, slope, b, speed_sign, 1), + lambda _: position_y + speed_sign * -step_size, + lambda _: jnp.array(position_y, float), + operand=None, + ), + lambda _: jnp.array(position_y, float), + operand=None, + ) + + new_x = jax.lax.cond( + move_x, + lambda _: jax.lax.cond( + self._is_on_line_for_position(position, slope, b, speed_sign, 2), + lambda _: position_x + speed_sign * car_direction_x * step_size, + lambda _: jnp.array(position_x, float), operand=None, ), + lambda _: jnp.array(position_x, float), operand=None, ) - wrapped_y = -((new_player_y * -1) % 1036) + wrapped_y = -((new_y * -1) % 1036) + + # Update road segment indices based on new position + segment_from_y = self._get_road_segment(new_y) + + # Update road indices to track the current segment + next_road_index_A = jax.lax.cond( + current_road == 0, + lambda _: segment_from_y, + lambda _: road_index_A, + operand=None, + ) + + next_road_index_B = jax.lax.cond( + current_road == 1, + lambda _: segment_from_y, + lambda _: road_index_B, + operand=None, + ) return Car( position=EntityPosition( - x=new_player_x, + x=new_x, y=wrapped_y, width=width, height=height, ), speed=speed, direction_x=car_direction_x, - current_road=updated_current_road, + current_road=current_road, road_index_A=next_road_index_A, road_index_B=next_road_index_B, type=car_type, @@ -563,11 +748,7 @@ def select_type(rand_val): type_id_spawn = select_type(rand_type) # Calculate X position on road - def get_road_segment(y): - segments = jnp.sum(self.consts.TRACK_CORNERS_Y > y) - return jnp.clip(segments - 1, 0, len(self.consts.TRACK_CORNERS_Y) - 2) - - segment_spawn = get_road_segment(y_spawn) + segment_spawn = self._get_road_segment(y_spawn) x_spawn = jax.lax.cond( road_spawn == 0, lambda _: self._get_x_on_road(y_spawn, segment_spawn, self.consts.FIRST_TRACK_CORNERS_X), @@ -716,20 +897,145 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: operand=player_speed, ) - is_jumping = jnp.logical_or(jnp.logical_and(state.is_jumping, state.jump_cooldown > 0), jnp.logical_and(state.is_on_road, jnp.logical_and(player_speed >= 0, jnp.logical_and(state.jump_cooldown == 0, jump)))) + # Check if on a steep road section (no X direction change) and apply speed reduction + # This simulates steep road sections that require a jump to pass when going upward + is_on_steep_road = self._is_steep_road_segment( + state.player_car.current_road, + state.player_car.road_index_A, + state.player_car.road_index_B, + ) + # Only apply steep road penalty when: + # 1. Player is on a steep road section + # 2. Player is not jumping + # 3. Player has positive speed (going upward) + on_steep_going_up = jnp.logical_and( + is_on_steep_road, + jnp.logical_and( + jnp.logical_not(state.is_jumping), + player_speed > 0 + ) + ) + # Update steep road timer - increment when on steep road going up, reset otherwise + steep_road_timer = jax.lax.cond( + on_steep_going_up, + lambda _: state.steep_road_timer + 1, + lambda _: jnp.array(0, dtype=jnp.int32), + operand=None, + ) + # Only reduce speed when timer reaches the interval threshold + should_reduce_speed = jnp.logical_and( + on_steep_going_up, + steep_road_timer >= self.consts.STEEP_ROAD_SPEED_REDUCTION_INTERVAL + ) + # Gradually reduce speed toward -2 when on steep section without jumping + player_speed = jax.lax.cond( + should_reduce_speed, + lambda s: jnp.maximum(s - 1, jnp.int32(-2)), + lambda s: s, + operand=player_speed, + ) + # Reset timer after speed reduction + steep_road_timer = jax.lax.cond( + should_reduce_speed, + lambda _: jnp.array(0, dtype=jnp.int32), + lambda _: steep_road_timer, + operand=None, + ) + + can_start_jump = jnp.logical_and(state.jump_cooldown == 0, state.post_jump_cooldown == 0) + is_jumping = jnp.logical_or( + jnp.logical_and(state.is_jumping, state.jump_cooldown > 0), + jnp.logical_and(state.is_on_road, jnp.logical_and(player_speed >= 0, jnp.logical_and(can_start_jump, jump))), + ) + + # Detect when a new jump is starting (was not jumping, now is jumping) + starting_jump = jnp.logical_and(is_jumping, jnp.logical_not(state.is_jumping)) + + # Calculate jump slope at jump start (X change per Y step) + # Uses the road segment slope to follow the road trajectory + road_index = jax.lax.cond( + state.player_car.current_road == 0, + lambda _: state.player_car.road_index_A, + lambda _: state.player_car.road_index_B, + operand=None, + ) + + # Get corner coordinates for the current segment + # Segment goes from corner[road_index] to corner[road_index+1] + start_x = jax.lax.cond( + state.player_car.current_road == 0, + lambda _: self.consts.FIRST_TRACK_CORNERS_X[road_index], + lambda _: self.consts.SECOND_TRACK_CORNERS_X[road_index], + operand=None, + ) + end_x = jax.lax.cond( + state.player_car.current_road == 0, + lambda _: self.consts.FIRST_TRACK_CORNERS_X[road_index +1], + lambda _: self.consts.SECOND_TRACK_CORNERS_X[road_index +1], + operand=None, + ) + start_y = self.consts.TRACK_CORNERS_Y[road_index] + + end_y = jax.lax.cond( + jnp.equal(self.consts.FIRST_TRACK_CORNERS_X[road_index + 1], self.consts.FIRST_TRACK_CORNERS_X[road_index + 2]), + lambda _: self.consts.TRACK_CORNERS_Y[road_index + 2], + lambda _: self.consts.TRACK_CORNERS_Y[road_index + 1], + operand=None + ) + + # Calculate slope: how much X changes per unit Y change + delta_x = end_x - start_x + delta_y = end_y - start_y + # Avoid division by zero for horizontal segments + new_jump_slope = jax.lax.cond( + jnp.abs(delta_y) > 0.001, + lambda _: jnp.float32(delta_x) / jnp.float32(delta_y), + lambda _: jnp.float32(0.0), + operand=None, + ) + + # Lock slope at jump start, keep previous slope during jump + jump_slope = jax.lax.cond( + starting_jump, + lambda _: new_jump_slope, + lambda _: state.jump_slope, + operand=None, + ) + jump_cooldown = jax.lax.cond( state.jump_cooldown > 0, lambda s: s - 1, - lambda s: jax.lax.cond(is_jumping, - lambda _: self.consts.JUMP_FRAMES, - lambda _: 0, - operand=None), + lambda s: jax.lax.cond( + is_jumping, + lambda _: self.consts.JUMP_FRAMES, + lambda _: 0, + operand=None, + ), operand=state.jump_cooldown, ) + + post_jump_cooldown = jax.lax.cond( + jnp.logical_and(state.jump_cooldown == 1, jump_cooldown == 0), + lambda _: self.consts.POST_JUMP_DELAY, + lambda _: jax.lax.cond( + state.post_jump_cooldown > 0, + lambda s: s - 1, + lambda s: s, + operand=state.post_jump_cooldown, + ), + operand=None, + ) is_on_road = ~is_jumping is_landing = jnp.logical_and(state.jump_cooldown == 1, jump_cooldown == 0) - updated_player_car = self._advance_car_core( + respawn_cooldown = jax.lax.cond( + state.respawn_cooldown > 0, + lambda _: state.respawn_cooldown - 1, + lambda _: jnp.array(0, dtype=jnp.int32), + operand=None, + ) + + updated_player_car = self._advance_player_car( position_x=state.player_car.position.x, position_y=state.player_car.position.y, road_index_A=state.player_car.road_index_A, @@ -742,30 +1048,41 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: width=state.player_car.position.width, height=state.player_car.position.height, car_type=state.player_car.type, - landing_check=is_landing, + is_landing=is_landing, + stored_jump_slope=jump_slope, ) - return UpNDownState( - score=state.score, - difficulty=state.difficulty, + # Check if a speed-changing action (UP or DOWN) was taken + speed_action_taken = jnp.logical_or(up, down) + # Round starts only after a speed-changing action + round_started_now = jnp.logical_or(state.round_started, speed_action_taken) + + next_state = state._replace( + respawn_cooldown=respawn_cooldown, jump_cooldown=jump_cooldown, + post_jump_cooldown=post_jump_cooldown, is_jumping=is_jumping, is_on_road=is_on_road, player_car=updated_player_car, step_counter=state.step_counter + 1, - round_started=jnp.logical_or(state.round_started, player_speed != 0), + round_started=round_started_now, movement_steps=jax.lax.cond( - jnp.logical_or(state.round_started, player_speed != 0), - lambda s: state.movement_steps + 1, - lambda s: state.movement_steps, + round_started_now, + lambda _: state.movement_steps + 1, + lambda _: state.movement_steps, operand=None, ), - flags=state.flags, - flags_collected_mask=state.flags_collected_mask, - collectibles=state.collectibles, - collectible_spawn_timer=state.collectible_spawn_timer, - enemy_cars=state.enemy_cars, - enemy_spawn_timer=state.enemy_spawn_timer, + steep_road_timer=steep_road_timer, + jump_slope=jump_slope, + ) + + water_crash = jnp.logical_and(is_landing, updated_player_car.current_road == 2) + + return jax.lax.cond( + water_crash, + lambda _: self._respawn_after_collision(next_state, next_state.lives - 1), + lambda _: next_state, + operand=None, ) def _flag_step_main(self, state: UpNDownState) -> UpNDownState: @@ -778,25 +1095,10 @@ def _flag_step_main(self, state: UpNDownState) -> UpNDownState: state, new_player_y, player_x, current_road ) - return UpNDownState( + return state._replace( score=state.score + flag_score, - difficulty=state.difficulty, - jump_cooldown=state.jump_cooldown, - is_jumping=state.is_jumping, - is_on_road=state.is_on_road, - player_car=state.player_car, - lives=state.lives, - is_dead=state.is_dead, - respawn_timer=state.respawn_timer, - step_counter=state.step_counter, - round_started=state.round_started, - movement_steps=state.movement_steps, flags=new_flags, flags_collected_mask=new_flags_collected_mask, - collectibles=state.collectibles, - collectible_spawn_timer=state.collectible_spawn_timer, - enemy_cars=state.enemy_cars, - enemy_spawn_timer=state.enemy_spawn_timer, ) def _collectible_step_main(self, state: UpNDownState) -> UpNDownState: @@ -809,25 +1111,76 @@ def _collectible_step_main(self, state: UpNDownState) -> UpNDownState: state, new_player_y, player_x, current_road ) - return UpNDownState( + return state._replace( score=state.score + collectible_score, - difficulty=state.difficulty, - jump_cooldown=state.jump_cooldown, - is_jumping=state.is_jumping, - is_on_road=state.is_on_road, - player_car=state.player_car, - lives=state.lives, - is_dead=state.is_dead, - respawn_timer=state.respawn_timer, - step_counter=state.step_counter, - round_started=state.round_started, - movement_steps=state.movement_steps, - flags=state.flags, - flags_collected_mask=state.flags_collected_mask, collectibles=updated_collectibles, collectible_spawn_timer=new_collectible_timer, - enemy_cars=state.enemy_cars, - enemy_spawn_timer=state.enemy_spawn_timer, + ) + + def _initialize_collectibles(self) -> Collectible: + """Return a cleared collectible pool.""" + return Collectible( + y=jnp.zeros(self.consts.MAX_COLLECTIBLES), + x=jnp.zeros(self.consts.MAX_COLLECTIBLES), + road=jnp.zeros(self.consts.MAX_COLLECTIBLES, dtype=jnp.int32), + color_idx=jnp.zeros(self.consts.MAX_COLLECTIBLES, dtype=jnp.int32), + type_id=jnp.zeros(self.consts.MAX_COLLECTIBLES, dtype=jnp.int32), + active=jnp.zeros(self.consts.MAX_COLLECTIBLES, dtype=jnp.bool_), + ) + + @partial(jax.jit, static_argnums=(0,)) + def _initialize_enemies(self, key: chex.Array, player_start_y: chex.Array) -> EnemyCars: + """Seed the initial set of visible enemies around the player.""" + key_init, key_type, key_road, key_speed, key_sign = jax.random.split(key, 5) + + offsets = self.consts.INITIAL_ENEMY_BASE_OFFSET + self.consts.INITIAL_ENEMY_GAP * jnp.arange(self.consts.INITIAL_ENEMY_COUNT) + spawn_signs = jax.random.choice(key_sign, jnp.array([-1.0, 1.0]), shape=(self.consts.INITIAL_ENEMY_COUNT,)) + raw_spawn_y = player_start_y + spawn_signs * offsets + init_y = -(((raw_spawn_y) * -1) % self.consts.TRACK_LENGTH) + init_road = jax.random.randint(key_road, shape=(self.consts.INITIAL_ENEMY_COUNT,), minval=0, maxval=2) + + init_segments = jax.vmap(self._get_road_segment)(init_y) + + init_x = jax.vmap(lambda y, seg, road: jax.lax.cond( + road == 0, + lambda _: self._get_x_on_road(y, seg, self.consts.FIRST_TRACK_CORNERS_X), + lambda _: self._get_x_on_road(y, seg, self.consts.SECOND_TRACK_CORNERS_X), + operand=None, + ))(init_y, init_segments, init_road) + + init_type = jax.random.randint(key_type, shape=(self.consts.INITIAL_ENEMY_COUNT,), minval=0, maxval=4) + init_speed_mag = jax.random.randint(key_speed, shape=(self.consts.INITIAL_ENEMY_COUNT,), minval=self.consts.ENEMY_SPEED_MIN, maxval=self.consts.ENEMY_SPEED_MAX + 1) + init_speed_sign = jax.random.choice(key_init, jnp.array([-1, 1]), shape=(self.consts.INITIAL_ENEMY_COUNT,)) + init_speed = init_speed_mag * init_speed_sign + + def init_direction(seg, road): + raw = jax.lax.cond( + road == 0, + lambda _: self.consts.FIRST_TRACK_CORNERS_X[seg+1] - self.consts.FIRST_TRACK_CORNERS_X[seg], + lambda _: self.consts.SECOND_TRACK_CORNERS_X[seg+1] - self.consts.SECOND_TRACK_CORNERS_X[seg], + operand=None, + ) + return jax.lax.cond(raw > 0, lambda _: 1, lambda _: -1, operand=None) + + init_dir = jax.vmap(init_direction)(init_segments, init_road) + + pad = self.consts.MAX_ENEMY_CARS - self.consts.INITIAL_ENEMY_COUNT + + return EnemyCars( + position=EntityPosition( + x=jnp.concatenate([init_x, jnp.zeros(pad, dtype=jnp.float32)]), + y=jnp.concatenate([init_y, jnp.zeros(pad, dtype=jnp.float32)]), + width=jnp.full((self.consts.MAX_ENEMY_CARS,), self.consts.PLAYER_SIZE[0]), + height=jnp.full((self.consts.MAX_ENEMY_CARS,), self.consts.PLAYER_SIZE[1]), + ), + speed=jnp.concatenate([init_speed, jnp.zeros(pad, dtype=jnp.int32)]), + type=jnp.concatenate([init_type, jnp.zeros(pad, dtype=jnp.int32)]), + current_road=jnp.concatenate([init_road, jnp.zeros(pad, dtype=jnp.int32)]), + road_index_A=jnp.concatenate([init_segments, jnp.zeros(pad, dtype=jnp.int32)]), + road_index_B=jnp.concatenate([init_segments, jnp.zeros(pad, dtype=jnp.int32)]), + direction_x=jnp.concatenate([init_dir, jnp.zeros(pad, dtype=jnp.int32)]), + active=jnp.concatenate([jnp.ones(self.consts.INITIAL_ENEMY_COUNT, dtype=jnp.bool_), jnp.zeros(pad, dtype=jnp.bool_)]), + age=jnp.concatenate([jnp.zeros(self.consts.INITIAL_ENEMY_COUNT, dtype=jnp.int32), jnp.zeros(pad, dtype=jnp.int32)]), ) @partial(jax.jit, static_argnums=(0,)) @@ -861,11 +1214,7 @@ def _enemy_step_main(self, state: UpNDownState) -> UpNDownState: spawn_y = -(((raw_spawn_y) * -1) % 1036) spawn_road = jax.random.randint(key_spawn_direction, shape=(), minval=0, maxval=2) - def get_road_segment(y): - segments = jnp.sum(self.consts.TRACK_CORNERS_Y > y) - return jnp.clip(segments - 1, 0, len(self.consts.TRACK_CORNERS_Y) - 2) - - segment_spawn = get_road_segment(spawn_y) + segment_spawn = self._get_road_segment(spawn_y) spawn_x = jax.lax.cond( spawn_road == 0, lambda _: self._get_x_on_road(spawn_y, segment_spawn, self.consts.FIRST_TRACK_CORNERS_X), @@ -910,13 +1259,10 @@ def get_road_segment(y): road_index_B=rb, current_road=cr, speed=sp, - is_jumping=False, - is_on_road=True, step_counter=state.step_counter, width=self.consts.PLAYER_SIZE[0], height=self.consts.PLAYER_SIZE[1], car_type=tp, - landing_check=False, ) advanced_cars = jax.vmap(move_fn)( @@ -963,22 +1309,155 @@ def get_road_segment(y): age=enemy_age, ) + return state._replace( + enemy_cars=next_enemy_cars, + enemy_spawn_timer=spawn_timer, + ) + + @partial(jax.jit, static_argnums=(0,)) + def _respawn_after_collision(self, state: UpNDownState, new_lives: chex.Array) -> UpNDownState: + """Respawn the player on a random road while preserving score and flags.""" + base_key = jax.random.PRNGKey(1337) + key_spawn = jax.random.fold_in(base_key, state.step_counter) + road_key, enemy_key = jax.random.split(key_spawn, 2) + + player_start_y = jnp.array(0.0) + start_segment = jnp.array(0, dtype=jnp.int32) + respawn_road = jax.random.randint(road_key, shape=(), minval=0, maxval=2) + + start_x = jax.lax.cond( + respawn_road == 0, + lambda _: self._get_x_on_road(player_start_y, start_segment, self.consts.FIRST_TRACK_CORNERS_X), + lambda _: self._get_x_on_road(player_start_y, start_segment, self.consts.SECOND_TRACK_CORNERS_X), + operand=None, + ) + + enemy_cars = self._initialize_enemies(enemy_key, player_start_y) + collectibles = self._initialize_collectibles() + + player_car = Car( + position=EntityPosition( + x=jnp.asarray(start_x, dtype=jnp.float32), + y=jnp.asarray(player_start_y, dtype=jnp.float32), + width=self.consts.PLAYER_SIZE[0], + height=self.consts.PLAYER_SIZE[1], + ), + speed=jnp.array(0, dtype=jnp.int32), + direction_x=jnp.array(0, dtype=jnp.int32), + current_road=respawn_road, + road_index_A=start_segment, + road_index_B=start_segment, + type=jnp.array(0, dtype=jnp.int32), + ) + return UpNDownState( score=state.score, + lives=new_lives, + respawn_cooldown=jnp.array(self.consts.RESPAWN_HIDE_FRAMES, dtype=jnp.int32), difficulty=state.difficulty, - jump_cooldown=state.jump_cooldown, - is_jumping=state.is_jumping, - is_on_road=state.is_on_road, - player_car=state.player_car, + jump_cooldown=jnp.array(0, dtype=jnp.int32), + post_jump_cooldown=jnp.array(0, dtype=jnp.int32), + is_jumping=jnp.array(False), + is_on_road=jnp.array(True), + player_car=player_car, step_counter=state.step_counter, - round_started=state.round_started, - movement_steps=state.movement_steps, + round_started=jnp.array(False), + movement_steps=jnp.array(0), + steep_road_timer=jnp.array(0, dtype=jnp.int32), + jump_slope=jnp.array(0.0, dtype=jnp.float32), flags=state.flags, flags_collected_mask=state.flags_collected_mask, - collectibles=state.collectibles, - collectible_spawn_timer=state.collectible_spawn_timer, - enemy_cars=next_enemy_cars, - enemy_spawn_timer=spawn_timer, + collectibles=collectibles, + collectible_spawn_timer=jnp.array(0, dtype=jnp.int32), + enemy_cars=enemy_cars, + enemy_spawn_timer=jnp.array(self.consts.ENEMY_SPAWN_INTERVAL, dtype=jnp.int32), + ) + + @partial(jax.jit, static_argnums=(0,)) + def _enemy_collision_step_main(self, state: UpNDownState) -> UpNDownState: + """Handle collisions between the player and enemy cars. + + - While airborne, collisions are ignored except for the final jump frames, + where hitting an enemy despawns it and awards a bonus. + - On ground collisions, the player loses a life and the stage soft-resets + without clearing score or collected flags. + - Landing collisions use a larger distance and are road-independent (for crossings). + """ + player_x = state.player_car.position.x + player_y = state.player_car.position.y + + dx = jnp.abs(state.enemy_cars.position.x - player_x) + dy = jnp.abs(state.enemy_cars.position.y - player_y) + wrapped_dy = jnp.minimum(dy, self.consts.TRACK_LENGTH - dy) + + # For ground collision: only trigger when enemy position is within 3 pixels + ground_collision_distance = 3.0 + overlap_x_ground = dx <= ground_collision_distance + overlap_y_ground = wrapped_dy <= ground_collision_distance + # For landing collision: use larger distance and road-independent (for crossings) + overlap_x_landing = dx <= self.consts.LANDING_COLLISION_DISTANCE + overlap_y_landing = wrapped_dy <= self.consts.LANDING_COLLISION_DISTANCE + # For late jump collision: use original larger overlap based on car dimensions + overlap_x_jump = dx <= (state.player_car.position.width + state.enemy_cars.position.width) / 2.0 + overlap_y_jump = wrapped_dy <= (state.player_car.position.height + state.enemy_cars.position.height) / 2.0 + same_road = state.enemy_cars.current_road == state.player_car.current_road + + # Ground collision mask uses tight 3-pixel distance and same road + ground_collision_mask = jnp.logical_and(state.enemy_cars.active, jnp.logical_and(same_road, jnp.logical_and(overlap_x_ground, overlap_y_ground))) + # Landing collision mask uses larger distance and is road-independent (for crossings) + landing_collision_mask = jnp.logical_and(state.enemy_cars.active, jnp.logical_and(overlap_x_landing, overlap_y_landing)) + # Jump collision mask uses original larger overlap (for scoring when jumping on enemies) + jump_collision_mask = jnp.logical_and(state.enemy_cars.active, jnp.logical_and(same_road, jnp.logical_and(overlap_x_jump, overlap_y_jump))) + collision_mask = jump_collision_mask # For late jump scoring + + any_jump_collision = jnp.any(jump_collision_mask) + any_ground_collision = jnp.any(ground_collision_mask) + any_landing_collision = jnp.any(landing_collision_mask) + + # Check if player is in the landing phase (just landed from a jump) + is_landing_phase = jnp.logical_and(state.post_jump_cooldown > 0, state.post_jump_cooldown <= self.consts.POST_JUMP_DELAY) + + late_jump_window = jnp.logical_and(state.is_jumping, state.jump_cooldown <= self.consts.LATE_JUMP_COLLISION_FRAMES) + late_jump_collision = jnp.logical_and(any_jump_collision, late_jump_window) + grounded_collision = jnp.logical_and(any_ground_collision, jnp.logical_not(state.is_jumping)) + # Landing collision is road-independent and uses larger distance + landing_collision = jnp.logical_and(any_landing_collision, is_landing_phase) + + def handle_late_jump(): + hits = collision_mask.astype(jnp.int32) + bonus = jnp.sum(hits) * self.consts.LATE_JUMP_ENEMY_SCORE + new_enemy_active = jnp.logical_and(state.enemy_cars.active, jnp.logical_not(collision_mask)) + new_enemy_age = jnp.where(collision_mask, jnp.zeros_like(state.enemy_cars.age), state.enemy_cars.age) + new_enemy_cars = EnemyCars( + position=state.enemy_cars.position, + speed=state.enemy_cars.speed, + type=state.enemy_cars.type, + current_road=state.enemy_cars.current_road, + road_index_A=state.enemy_cars.road_index_A, + road_index_B=state.enemy_cars.road_index_B, + direction_x=state.enemy_cars.direction_x, + active=new_enemy_active, + age=new_enemy_age, + ) + + return state._replace(score=state.score + bonus, enemy_cars=new_enemy_cars) + + def handle_ground_collision(): + return self._respawn_after_collision(state, state.lives - 1) + + # Check for any collision that should cause respawn (ground or landing) + any_fatal_collision = jnp.logical_or(grounded_collision, landing_collision) + + return jax.lax.cond( + late_jump_collision, + lambda _: handle_late_jump(), + lambda _: jax.lax.cond( + any_fatal_collision, + lambda _: handle_ground_collision(), + lambda _: state, + operand=None, + ), + operand=None, ) def _passive_score_step_main(self, state: UpNDownState) -> UpNDownState: @@ -990,26 +1469,7 @@ def _passive_score_step_main(self, state: UpNDownState) -> UpNDownState: operand=None, ) - return UpNDownState( - score=state.score + bonus, - difficulty=state.difficulty, - jump_cooldown=state.jump_cooldown, - is_jumping=state.is_jumping, - is_on_road=state.is_on_road, - player_car=state.player_car, - lives=state.lives, - is_dead=state.is_dead, - respawn_timer=state.respawn_timer, - step_counter=state.step_counter, - round_started=state.round_started, - movement_steps=state.movement_steps, - flags=state.flags, - flags_collected_mask=state.flags_collected_mask, - collectibles=state.collectibles, - collectible_spawn_timer=state.collectible_spawn_timer, - enemy_cars=state.enemy_cars, - enemy_spawn_timer=state.enemy_spawn_timer, - ) + return state._replace(score=state.score + bonus) def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: @@ -1018,25 +1478,17 @@ def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: if key is None: key = jax.random.PRNGKey(42) + key, flag_key, enemy_key = jax.random.split(key, 3) # Evenly spread flags along the track with small jitter - key, subkey = jax.random.split(key) base_y = jnp.linspace(-900.0, -100.0, self.consts.NUM_FLAGS) - jitter = jax.random.uniform(subkey, shape=(self.consts.NUM_FLAGS,), minval=-40.0, maxval=40.0) + jitter = jax.random.uniform(flag_key, shape=(self.consts.NUM_FLAGS,), minval=-40.0, maxval=40.0) flag_y_offsets = base_y + jitter # Alternate roads 0/1 for variety flag_roads = jnp.arange(self.consts.NUM_FLAGS) % 2 - - - # Calculate which road segment each flag is on based on Y position - def get_road_segment(y): - # Find the segment where TRACK_CORNERS_Y[i] > y >= TRACK_CORNERS_Y[i+1] - segments = jnp.sum(self.consts.TRACK_CORNERS_Y > y) - return jnp.clip(segments - 1, 0, len(self.consts.TRACK_CORNERS_Y) - 2) - - flag_segments = jax.vmap(get_road_segment)(flag_y_offsets) + flag_segments = jax.vmap(self._get_road_segment)(flag_y_offsets) # Each flag color index corresponds to its position (0-7) flag_color_indices = jnp.arange(self.consts.NUM_FLAGS) @@ -1050,78 +1502,25 @@ def get_road_segment(y): ) # Initialize collectibles as all inactive (will spawn dynamically with mixed types) - collectibles = Collectible( - y=jnp.zeros(self.consts.MAX_COLLECTIBLES), - x=jnp.zeros(self.consts.MAX_COLLECTIBLES), - road=jnp.zeros(self.consts.MAX_COLLECTIBLES, dtype=jnp.int32), - color_idx=jnp.zeros(self.consts.MAX_COLLECTIBLES, dtype=jnp.int32), - type_id=jnp.zeros(self.consts.MAX_COLLECTIBLES, dtype=jnp.int32), - active=jnp.zeros(self.consts.MAX_COLLECTIBLES, dtype=jnp.bool_), - ) - - def get_road_segment(y): - segments = jnp.sum(self.consts.TRACK_CORNERS_Y > y) - return jnp.clip(segments - 1, 0, len(self.consts.TRACK_CORNERS_Y) - 2) + collectibles = self._initialize_collectibles() # Seed initial visible enemies spaced around the player - key_init, key_type, key_road, key_speed, key_sign = jax.random.split(key, 5) - player_start_y = 0.0 - offsets = self.consts.INITIAL_ENEMY_BASE_OFFSET + self.consts.INITIAL_ENEMY_GAP * jnp.arange(self.consts.INITIAL_ENEMY_COUNT) - spawn_signs = jax.random.choice(key_sign, jnp.array([-1.0, 1.0]), shape=(self.consts.INITIAL_ENEMY_COUNT,)) - raw_spawn_y = player_start_y + spawn_signs * offsets - init_y = -(((raw_spawn_y) * -1) % 1036) - init_road = jax.random.randint(key_road, shape=(self.consts.INITIAL_ENEMY_COUNT,), minval=0, maxval=2) - init_segments = jax.vmap(get_road_segment)(init_y) - init_x = jax.vmap(lambda y, seg, road: jax.lax.cond( - road == 0, - lambda _: self._get_x_on_road(y, seg, self.consts.FIRST_TRACK_CORNERS_X), - lambda _: self._get_x_on_road(y, seg, self.consts.SECOND_TRACK_CORNERS_X), - operand=None, - ))(init_y, init_segments, init_road) - init_type = jax.random.randint(key_type, shape=(self.consts.INITIAL_ENEMY_COUNT,), minval=0, maxval=4) - init_speed_mag = jax.random.randint(key_speed, shape=(self.consts.INITIAL_ENEMY_COUNT,), minval=self.consts.ENEMY_SPEED_MIN, maxval=self.consts.ENEMY_SPEED_MAX + 1) - init_speed_sign = jax.random.choice(key_init, jnp.array([-1, 1]), shape=(self.consts.INITIAL_ENEMY_COUNT,)) - init_speed = init_speed_mag * init_speed_sign - - def init_direction(seg, road): - raw = jax.lax.cond( - road == 0, - lambda _: self.consts.FIRST_TRACK_CORNERS_X[seg+1] - self.consts.FIRST_TRACK_CORNERS_X[seg], - lambda _: self.consts.SECOND_TRACK_CORNERS_X[seg+1] - self.consts.SECOND_TRACK_CORNERS_X[seg], - operand=None, - ) - return jax.lax.cond(raw > 0, lambda _: 1, lambda _: -1, operand=None) - - init_dir = jax.vmap(init_direction)(init_segments, init_road) - - pad = self.consts.MAX_ENEMY_CARS - self.consts.INITIAL_ENEMY_COUNT - enemy_cars = EnemyCars( - position=EntityPosition( - x=jnp.concatenate([init_x, jnp.zeros(pad, dtype=jnp.float32)]), - y=jnp.concatenate([init_y, jnp.zeros(pad, dtype=jnp.float32)]), - width=jnp.full((self.consts.MAX_ENEMY_CARS,), self.consts.PLAYER_SIZE[0]), - height=jnp.full((self.consts.MAX_ENEMY_CARS,), self.consts.PLAYER_SIZE[1]), - ), - speed=jnp.concatenate([init_speed, jnp.zeros(pad, dtype=jnp.int32)]), - type=jnp.concatenate([init_type, jnp.zeros(pad, dtype=jnp.int32)]), - current_road=jnp.concatenate([init_road, jnp.zeros(pad, dtype=jnp.int32)]), - road_index_A=jnp.concatenate([init_segments, jnp.zeros(pad, dtype=jnp.int32)]), - road_index_B=jnp.concatenate([init_segments, jnp.zeros(pad, dtype=jnp.int32)]), - direction_x=jnp.concatenate([init_dir, jnp.zeros(pad, dtype=jnp.int32)]), - active=jnp.concatenate([jnp.ones(self.consts.INITIAL_ENEMY_COUNT, dtype=jnp.bool_), jnp.zeros(pad, dtype=jnp.bool_)]), - age=jnp.concatenate([jnp.zeros(self.consts.INITIAL_ENEMY_COUNT, dtype=jnp.int32), jnp.zeros(pad, dtype=jnp.int32)]), - ) + player_start_y = jnp.array(0.0) + enemy_cars = self._initialize_enemies(enemy_key, player_start_y) state = UpNDownState( score=0, + lives=jnp.array(self.consts.INITIAL_LIVES, dtype=jnp.int32), + respawn_cooldown=jnp.array(0, dtype=jnp.int32), difficulty=self.consts.DIFFICULTIES[0], jump_cooldown=0, + post_jump_cooldown=0, is_jumping=False, is_on_road=True, player_car=Car( position=EntityPosition( - x=30, - y= 0, + x=jnp.asarray(30, dtype=jnp.float32), + y=jnp.asarray(0, dtype=jnp.float32), width=self.consts.PLAYER_SIZE[0], height=self.consts.PLAYER_SIZE[1], ), @@ -1135,6 +1534,8 @@ def init_direction(seg, road): step_counter=jnp.array(0), round_started=jnp.array(False), movement_steps=jnp.array(0), + steep_road_timer=jnp.array(0, dtype=jnp.int32), + jump_slope=jnp.array(0.0, dtype=jnp.float32), flags=flags, flags_collected_mask=jnp.zeros(self.consts.NUM_FLAGS, dtype=jnp.bool_), collectibles=collectibles, @@ -1156,6 +1557,7 @@ def step(self, state: UpNDownState, action: chex.Array) -> Tuple[UpNDownObservat state = self._completion_bonus_step(state) state = self._collectible_step_main(state) state = self._enemy_step_main(state) + state = self._enemy_collision_step_main(state) done = self._get_done(state) env_reward = self._get_reward(previous_state, state) @@ -1166,28 +1568,30 @@ def step(self, state: UpNDownState, action: chex.Array) -> Tuple[UpNDownObservat def render(self, state: UpNDownState) -> jnp.ndarray: - return self.renderer.render(state) + frame = self.renderer.render(state) + return jnp.asarray(frame, dtype=jnp.uint8) def _get_observation(self, state: UpNDownState): + # Clamp to screen-friendly coordinates so observation_space.contains passes + x = jnp.int32(jnp.clip(state.player_car.position.x, 0, 160)) + screen_y = jnp.int32(105) + player = EntityPosition( - x=jnp.array(state.player_car.position.x), - y=jnp.array(state.player_car.position.y), - width=jnp.array(self.consts.PLAYER_SIZE[0]), - height=jnp.array(self.consts.PLAYER_SIZE[1]), - ) - return UpNDownObservation( - player=player, + x=x, + y=screen_y, + width=jnp.int32(self.consts.PLAYER_SIZE[0]), + height=jnp.int32(self.consts.PLAYER_SIZE[1]), ) + return UpNDownObservation(player=player) @partial(jax.jit, static_argnums=(0,)) def obs_to_flat_array(self, obs: UpNDownObservation) -> jnp.ndarray: - return jnp.concatenate([ - obs.player.x.flatten(), - obs.player.y.flatten(), - obs.player.height.flatten(), - obs.player.width.flatten(), - ] - ) + return jnp.concatenate([ + jnp.asarray(obs.player.x, dtype=jnp.int32).reshape(-1), + jnp.asarray(obs.player.y, dtype=jnp.int32).reshape(-1), + jnp.asarray(obs.player.height, dtype=jnp.int32).reshape(-1), + jnp.asarray(obs.player.width, dtype=jnp.int32).reshape(-1), + ]) def action_space(self) -> spaces.Discrete: return spaces.Discrete(6) @@ -1212,20 +1616,19 @@ def image_space(self) -> spaces.Box: @partial(jax.jit, static_argnums=(0,)) def _get_info(self, state: UpNDownState, ) -> UpNDownInfo: - return UpNDownInfo(time=1) + return UpNDownInfo(time=jnp.asarray(state.step_counter, dtype=jnp.int32)) @partial(jax.jit, static_argnums=(0,)) def _get_reward(self, previous_state: UpNDownState, state: UpNDownState): - return state.score + base_delta = jnp.asarray(state.score - previous_state.score, dtype=jnp.float32) + if self.reward_funcs: + extras = jnp.sum(jnp.array([fn(previous_state, state) for fn in self.reward_funcs], dtype=jnp.float32)) + return base_delta + extras + return base_delta @partial(jax.jit, static_argnums=(0,)) def _get_done(self, state: UpNDownState) -> bool: - return jnp.logical_or( - state.lives <= 0, - jnp.all(state.flags_collected_mask), -) - - + return state.lives <= 0 class UpNDownRenderer(JAXGameRenderer): def __init__(self, consts: UpNDownConstants = None): @@ -1357,6 +1760,29 @@ def _compute_flag_palette_ids(self) -> jnp.ndarray: """Precompute palette indices for each flag color without special-casing pink.""" return jnp.array([self._find_palette_id(color) for color in self.consts.FLAG_COLORS], dtype=jnp.int32) + @partial(jax.jit, static_argnums=(0,)) + def _jump_arc_offset(self, jump_cooldown: chex.Array) -> chex.Array: + """Return a simple parabolic jump height based on remaining jump frames.""" + total = jnp.array(self.consts.JUMP_FRAMES, dtype=jnp.float32) + remaining = jnp.array(jump_cooldown, dtype=jnp.float32) + progress = jnp.clip((total - remaining) / jnp.maximum(total, 1.0), 0.0, 1.0) + centered = (progress - 0.5) * 2.0 + return self.consts.JUMP_ARC_HEIGHT * (1.0 - centered * centered) + + def _get_x_on_road(self, y: chex.Array, road_segment: chex.Array, track_corners_x: chex.Array) -> chex.Array: + """Linear interpolation of x along the given road segment for y.""" + y1 = self.consts.TRACK_CORNERS_Y[road_segment] + y2 = self.consts.TRACK_CORNERS_Y[road_segment + 1] + x1 = track_corners_x[road_segment] + x2 = track_corners_x[road_segment + 1] + t = jax.lax.cond( + y2 != y1, + lambda _: (y - y1) / (y2 - y1), + lambda _: 0.0, + operand=None, + ) + return x1 + t * (x2 - x1) + def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.ndarray, bottomBlockSprite: jnp.ndarray, tempPointer: jnp.ndarray, blackoutSquare: jnp.ndarray) -> tuple[list, list[str]]: """Returns the asset manifest and ordered road files.""" road_dir = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/roads" @@ -1442,223 +1868,214 @@ def combine(i, acc): raster = jax.lax.fori_loop(0, total_segments, combine, raster) - def select_enemy_mask(enemy_type: chex.Array, going_left: chex.Array): - left_mask = self.enemy_left_masks[enemy_type] - right_mask = self.enemy_right_masks[enemy_type] - return jax.lax.cond(going_left, lambda _: left_mask, lambda _: right_mask, operand=None) - - def render_enemy(carry, enemy_idx): - raster = carry - enemy_active = state.enemy_cars.active[enemy_idx] - enemy_x = state.enemy_cars.position.x[enemy_idx] - enemy_y = state.enemy_cars.position.y[enemy_idx] - enemy_type = state.enemy_cars.type[enemy_idx] - direction_x = state.enemy_cars.direction_x[enemy_idx] - screen_y = 105 + (enemy_y - state.player_car.position.y) - is_visible = jnp.logical_and(enemy_active, jnp.logical_and(screen_y > 25, screen_y < 195)) - enemy_mask = select_enemy_mask(enemy_type, direction_x < 0) - - raster = jax.lax.cond( - is_visible, - lambda r: self.jr.render_at(r, enemy_x.astype(jnp.int32), screen_y.astype(jnp.int32), enemy_mask), - lambda r: r, - operand=raster, + hide_world = state.respawn_cooldown > 0 + + # During respawn hide, only show the road/background to emulate an initial road state. + def render_roads_only(): + return raster + + def render_full_scene(): + def select_enemy_mask(enemy_type: chex.Array, going_left: chex.Array): + left_mask = self.enemy_left_masks[enemy_type] + right_mask = self.enemy_right_masks[enemy_type] + return jax.lax.cond(going_left, lambda _: left_mask, lambda _: right_mask, operand=None) + + def render_enemy(carry, enemy_idx): + raster = carry + enemy_active = state.enemy_cars.active[enemy_idx] + enemy_x = state.enemy_cars.position.x[enemy_idx] + enemy_y = state.enemy_cars.position.y[enemy_idx] + enemy_type = state.enemy_cars.type[enemy_idx] + direction_x = state.enemy_cars.direction_x[enemy_idx] + screen_y = 105 + (enemy_y - state.player_car.position.y) + is_visible = jnp.logical_and(enemy_active, jnp.logical_and(screen_y > 25, screen_y < 195)) + enemy_mask = select_enemy_mask(enemy_type, direction_x < 0) + + raster = jax.lax.cond( + is_visible, + lambda r: self.jr.render_at(r, enemy_x.astype(jnp.int32), screen_y.astype(jnp.int32), enemy_mask), + lambda r: r, + operand=raster, + ) + return raster, None + + raster_enemies, _ = jax.lax.scan(render_enemy, raster, jnp.arange(self.consts.MAX_ENEMY_CARS)) + + jump_offset = jax.lax.cond( + state.is_jumping, + lambda _: self._jump_arc_offset(state.jump_cooldown), + lambda _: jnp.array(0.0, dtype=jnp.float32), + operand=None, ) - return raster, None - raster, _ = jax.lax.scan(render_enemy, raster, jnp.arange(self.consts.MAX_ENEMY_CARS)) + player_screen_y = jnp.int32(105 - jump_offset) + player_mask = self.SHAPE_MASKS["player"] + raster_player = self.jr.render_at(raster_enemies, state.player_car.position.x, player_screen_y, player_mask) - player_mask = self.SHAPE_MASKS["player"] - raster = self.jr.render_at(raster, state.player_car.position.x, 105, player_mask) + wall_top_mask = self.SHAPE_MASKS["wall_top"] + raster_wall_top = self.jr.render_at(raster_player, 0, 0, wall_top_mask) - wall_top_mask = self.SHAPE_MASKS["wall_top"] - raster = self.jr.render_at(raster, 0, 0, wall_top_mask) + wall_bottom_mask = self.SHAPE_MASKS["wall_bottom"] + raster_wall_bottom = self.jr.render_at(raster_wall_top, 0, 210 - wall_bottom_mask.shape[0], wall_bottom_mask) - wall_bottom_mask = self.SHAPE_MASKS["wall_bottom"] - raster = self.jr.render_at(raster, 0, 210 - wall_bottom_mask.shape[0], wall_bottom_mask) + all_flags_top_mask = self.SHAPE_MASKS["all_flags_top"] + raster_flags_top = self.jr.render_at(raster_wall_bottom, 10, 20, all_flags_top_mask) - all_flags_top_mask = self.SHAPE_MASKS["all_flags_top"] - raster = self.jr.render_at(raster, 10, 20, all_flags_top_mask) + return raster_flags_top - # Render score centered at the top using dedicated score digit sprites - score_digits = self.jr.int_to_digits(state.score, max_digits=self.score_max_digits) - non_zero_mask = score_digits != 0 - has_non_zero = jnp.any(non_zero_mask) - first_non_zero = jnp.argmax(non_zero_mask) - start_index = jax.lax.select(has_non_zero, first_non_zero, self.score_max_digits - 1) - num_to_render = jax.lax.select(has_non_zero, self.score_max_digits - start_index, 1) - total_width = num_to_render * self.score_digit_spacing - score_x = self.score_center_x - (total_width // 2) + def render_rest(raster_input): + # Render score centered at the top using dedicated score digit sprites + score_digits = self.jr.int_to_digits(state.score, max_digits=self.score_max_digits) + non_zero_mask = score_digits != 0 + has_non_zero = jnp.any(non_zero_mask) + first_non_zero = jnp.argmax(non_zero_mask) + start_index = jax.lax.select(has_non_zero, first_non_zero, self.score_max_digits - 1) + num_to_render = jax.lax.select(has_non_zero, self.score_max_digits - start_index, 1) - raster = self.jr.render_label_selective( - raster, - jnp.int32(score_x), - self.score_render_y, - score_digits, - self.score_digit_masks, - start_index, - num_to_render, - spacing=self.score_digit_spacing, - max_digits_to_render=self.score_max_digits, - ) + total_width = num_to_render * self.score_digit_spacing + score_x = self.score_center_x - (total_width // 2) - # Render flags on the road - flag_pole_mask = self.SHAPE_MASKS["flag_pole"] - - def render_flag(carry, flag_idx): - raster = carry - flag_y = state.flags.y[flag_idx] - flag_road = state.flags.road[flag_idx] - flag_segment = state.flags.road_segment[flag_idx] - flag_collected = state.flags.collected[flag_idx] - flag_color_idx = state.flags.color_idx[flag_idx] - - # Calculate flag X position on its road - flag_x = jax.lax.cond( - flag_road == 0, - lambda _: self._get_flag_x_on_road(flag_y, flag_segment, self.consts.FIRST_TRACK_CORNERS_X), - lambda _: self._get_flag_x_on_road(flag_y, flag_segment, self.consts.SECOND_TRACK_CORNERS_X), - operand=None, - ) - - # Calculate screen Y position relative to player - # The player is always rendered at Y=105, so flags scroll based on player position - screen_y = 105 + (flag_y - state.player_car.position.y) - - # Check if flag is visible on screen and not collected - is_visible = jnp.logical_and( - jnp.logical_and(screen_y > 25, screen_y < 195), - ~flag_collected - ) - - # Colorize the base flag mask - color_id = self.flag_palette_ids[flag_color_idx] - colored_flag_mask = jnp.where( - self.flag_solid_mask, - color_id, - self.flag_base_mask, + raster_score = self.jr.render_label_selective( + raster_input, + jnp.int32(score_x), + self.score_render_y, + score_digits, + self.score_digit_masks, + start_index, + num_to_render, + spacing=self.score_digit_spacing, + max_digits_to_render=self.score_max_digits, ) + + # Render flags on the road + flag_pole_mask = self.SHAPE_MASKS["flag_pole"] - # Render flag if visible - raster = jax.lax.cond( - is_visible, - lambda r: self.jr.render_at( - self.jr.render_at(r, flag_x.astype(jnp.int32), screen_y.astype(jnp.int32), colored_flag_mask), - (flag_x + 5).astype(jnp.int32), screen_y.astype(jnp.int32), flag_pole_mask - ), - lambda r: r, - operand=raster, - ) - return raster, None - - raster, _ = jax.lax.scan(render_flag, raster, jnp.arange(self.consts.NUM_FLAGS)) - - # Black out collected flags at the top - blackout_mask = self.SHAPE_MASKS["blackout_square"] - - def render_blackout(carry, flag_idx): - raster = carry - flag_collected = state.flags_collected_mask[flag_idx] - blackout_x = self.consts.FLAG_TOP_X_POSITIONS[flag_idx] - blackout_y = self.consts.FLAG_TOP_Y + def render_flag(carry, flag_idx): + raster = carry + flag_y = state.flags.y[flag_idx] + flag_road = state.flags.road[flag_idx] + flag_segment = state.flags.road_segment[flag_idx] + flag_collected = state.flags.collected[flag_idx] + flag_color_idx = state.flags.color_idx[flag_idx] + + flag_x = jax.lax.cond( + flag_road == 0, + lambda _: self._get_x_on_road(flag_y, flag_segment, self.consts.FIRST_TRACK_CORNERS_X), + lambda _: self._get_x_on_road(flag_y, flag_segment, self.consts.SECOND_TRACK_CORNERS_X), + operand=None, + ) + screen_y = 105 + (flag_y - state.player_car.position.y) + is_visible = jnp.logical_and( + jnp.logical_and(screen_y > 25, screen_y < 195), + ~flag_collected + ) + color_id = self.flag_palette_ids[flag_color_idx] + colored_flag_mask = jnp.where( + self.flag_solid_mask, + color_id, + self.flag_base_mask, + ) + raster = jax.lax.cond( + is_visible, + lambda r: self.jr.render_at( + self.jr.render_at(r, flag_x.astype(jnp.int32), screen_y.astype(jnp.int32), colored_flag_mask), + (flag_x + 5).astype(jnp.int32), screen_y.astype(jnp.int32), flag_pole_mask + ), + lambda r: r, + operand=raster, + ) + return raster, None - raster = jax.lax.cond( - flag_collected, - lambda r: self.jr.render_at(r, blackout_x, blackout_y, blackout_mask), - lambda r: r, - operand=raster, - ) - return raster, None - - raster, _ = jax.lax.scan(render_blackout, raster, jnp.arange(self.consts.NUM_FLAGS)) - - # Render collectibles (unified for all types: cherry, balloon, lollypop, ice cream) - def render_collectible(carry, collectible_idx): - raster = carry - collectible_y = state.collectibles.y[collectible_idx] - collectible_x = state.collectibles.x[collectible_idx] - collectible_active = state.collectibles.active[collectible_idx] - collectible_color_idx = state.collectibles.color_idx[collectible_idx] - collectible_type_id = state.collectibles.type_id[collectible_idx] + raster_flags, _ = jax.lax.scan(render_flag, raster_score, jnp.arange(self.consts.NUM_FLAGS)) - # Calculate screen Y position relative to player - screen_y = 105 + (collectible_y - state.player_car.position.y) + blackout_mask = self.SHAPE_MASKS["blackout_square"] - # Check if collectible is visible on screen and active - is_visible = jnp.logical_and( - jnp.logical_and(screen_y > 25, screen_y < 195), - collectible_active - ) + def render_blackout(carry, flag_idx): + raster = carry + flag_collected = state.flags_collected_mask[flag_idx] + blackout_x = self.consts.FLAG_TOP_X_POSITIONS[flag_idx] + blackout_y = self.consts.FLAG_TOP_Y + raster = jax.lax.cond( + flag_collected, + lambda r: self.jr.render_at(r, blackout_x, blackout_y, blackout_mask), + lambda r: r, + operand=raster, + ) + return raster, None - # Select sprite based on type_id - # type_id: 0=cherry, 1=balloon, 2=lollypop, 3=ice_cream - def get_sprite_and_mask(type_id): - cherry_result = (self.cherry_base_mask, self.cherry_solid_mask, self.cherry_palette_ids) - balloon_result = (self.balloon_base_mask, self.balloon_solid_mask, self.balloon_palette_ids) - lollypop_result = (self.lollypop_base_mask, self.lollypop_solid_mask, self.lollypop_palette_ids) - ice_cream_result = (self.ice_cream_base_mask, self.ice_cream_solid_mask, self.ice_cream_palette_ids) - - # Use conditional branching to select sprite - result = jax.lax.cond( - type_id == self.consts.COLLECTIBLE_TYPE_CHERRY, - lambda _: cherry_result, - lambda _: jax.lax.cond( - type_id == self.consts.COLLECTIBLE_TYPE_BALLOON, - lambda _: balloon_result, + raster_blackout, _ = jax.lax.scan(render_blackout, raster_flags, jnp.arange(self.consts.NUM_FLAGS)) + + def render_collectible(carry, collectible_idx): + raster = carry + collectible_y = state.collectibles.y[collectible_idx] + collectible_x = state.collectibles.x[collectible_idx] + collectible_active = state.collectibles.active[collectible_idx] + collectible_color_idx = state.collectibles.color_idx[collectible_idx] + collectible_type_id = state.collectibles.type_id[collectible_idx] + screen_y = 105 + (collectible_y - state.player_car.position.y) + is_visible = jnp.logical_and( + jnp.logical_and(screen_y > 25, screen_y < 195), + collectible_active + ) + + def get_sprite_and_mask(type_id): + cherry_result = (self.cherry_base_mask, self.cherry_solid_mask, self.cherry_palette_ids) + balloon_result = (self.balloon_base_mask, self.balloon_solid_mask, self.balloon_palette_ids) + lollypop_result = (self.lollypop_base_mask, self.lollypop_solid_mask, self.lollypop_palette_ids) + ice_cream_result = (self.ice_cream_base_mask, self.ice_cream_solid_mask, self.ice_cream_palette_ids) + return jax.lax.cond( + type_id == self.consts.COLLECTIBLE_TYPE_CHERRY, + lambda _: cherry_result, lambda _: jax.lax.cond( - type_id == self.consts.COLLECTIBLE_TYPE_LOLLYPOP, - lambda _: lollypop_result, - lambda _: ice_cream_result, + type_id == self.consts.COLLECTIBLE_TYPE_BALLOON, + lambda _: balloon_result, + lambda _: jax.lax.cond( + type_id == self.consts.COLLECTIBLE_TYPE_LOLLYPOP, + lambda _: lollypop_result, + lambda _: ice_cream_result, + operand=None, + ), operand=None, ), operand=None, - ), - operand=None, + ) + + base_mask, solid_mask, palette_ids = get_sprite_and_mask(collectible_type_id) + color_id = palette_ids[collectible_color_idx] + colored_mask = jnp.where( + (base_mask != self.jr.TRANSPARENT_ID) & (base_mask != 0), + color_id, + base_mask, ) - return result - - base_mask, solid_mask, palette_ids = get_sprite_and_mask(collectible_type_id) - - # Only colorize inner pixels, keep black edges (palette ID 0 is black) - color_id = palette_ids[collectible_color_idx] - colored_mask = jnp.where( - (base_mask != self.jr.TRANSPARENT_ID) & (base_mask != 0), - color_id, - base_mask, - ) - - # Render collectible if visible - raster = jax.lax.cond( - is_visible, - lambda r: self.jr.render_at(r, collectible_x.astype(jnp.int32), screen_y.astype(jnp.int32), colored_mask), - lambda r: r, - operand=raster, - ) - return raster, None - - raster, _ = jax.lax.scan(render_collectible, raster, jnp.arange(self.consts.MAX_COLLECTIBLES)) + raster = jax.lax.cond( + is_visible, + lambda r: self.jr.render_at(r, collectible_x.astype(jnp.int32), screen_y.astype(jnp.int32), colored_mask), + lambda r: r, + operand=raster, + ) + return raster, None - all_lives_bottom_mask = self.SHAPE_MASKS["all_lives_bottom"] - raster = self.jr.render_at(raster, 10, 195, all_lives_bottom_mask) + raster_collectibles, _ = jax.lax.scan(render_collectible, raster_blackout, jnp.arange(self.consts.MAX_COLLECTIBLES)) - wall_bottom_mask = self.SHAPE_MASKS["tempPointer"] - raster = self.jr.render_at(raster, 140, 25, wall_bottom_mask) + all_lives_bottom_mask = self.SHAPE_MASKS["all_lives_bottom"] + raster_lives = self.jr.render_at(raster_collectibles, 10, 195, all_lives_bottom_mask) - return self.jr.render_from_palette(raster, self.PALETTE) - - def _get_flag_x_on_road(self, y: chex.Array, road_segment: chex.Array, track_corners_x: chex.Array) -> chex.Array: - """Calculate the X position on a road given a Y coordinate and road segment.""" - y1 = self.consts.TRACK_CORNERS_Y[road_segment] - y2 = self.consts.TRACK_CORNERS_Y[road_segment + 1] - x1 = track_corners_x[road_segment] - x2 = track_corners_x[road_segment + 1] - - # Linear interpolation: x = x1 + (y - y1) * (x2 - x1) / (y2 - y1) - t = jax.lax.cond( - y2 != y1, - lambda _: (y - y1) / (y2 - y1), - lambda _: 0.0, + wall_bottom_mask = self.SHAPE_MASKS["tempPointer"] + raster_pointer = self.jr.render_at(raster_lives, 140, 25, wall_bottom_mask) + + return self.jr.render_from_palette(raster_pointer, self.PALETTE) + + base_scene = jax.lax.cond( + hide_world, + lambda _: render_roads_only(), + lambda _: render_full_scene(), operand=None, ) - return x1 + t * (x2 - x1) \ No newline at end of file + + return jax.lax.cond( + hide_world, + lambda _: self.jr.render_from_palette(base_scene, self.PALETTE), + lambda _: render_rest(base_scene), + operand=None, + ) \ No newline at end of file From 53e30a5de52c7e1553ba75c1c5d1733ae9030214 Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sat, 20 Dec 2025 21:33:07 +0100 Subject: [PATCH 20/76] add missing live counter to game --- src/jaxatari/games/jax_upndown.py | 391 +++++++++++++++--------------- 1 file changed, 189 insertions(+), 202 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 2d8de2915..52e5cd88a 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -19,7 +19,6 @@ class UpNDownConstants(NamedTuple): ACTION_REPEAT_PROBS: float = 0.25 MAX_SPEED: int = 6 INITIAL_LIVES: int = 5 - RESPAWN_HIDE_FRAMES: int = 8 JUMP_ARC_HEIGHT: float = 18.0 # Enemy spawning and movement MAX_ENEMY_CARS: int = 8 @@ -71,6 +70,10 @@ class UpNDownConstants(NamedTuple): FLAG_TOP_Y: int = 20 FLAG_BLACKOUT_SIZE: Tuple[int, int] = (14, 14) # Size of blackout square FLAG_COLLECTION_SCORE: int = 75 # Points awarded for collecting a flag + # Life display constants - positions of life cars at the bottom + LIFE_BOTTOM_X_POSITIONS: chex.Array = jnp.array([13, 18, 25, 33, 33]) # X positions for 5 life cars + LIFE_BOTTOM_Y: int = 195 + LIFE_BLACKOUT_SIZE: Tuple[int, int] = (14, 14) # Size of blackout square for lives PICKUP_SCORE: int = 100 # Points awarded for jumping on a pickup truck FLAG_CARRIER_SCORE: int = 125 # Points awarded for jumping on a flag carrier CAMARO_SCORE: int = 150 # Points awarded for jumping on a camaro @@ -147,7 +150,6 @@ class EnemyCars(NamedTuple): class UpNDownState(NamedTuple): score: chex.Array lives: chex.Array - respawn_cooldown: chex.Array difficulty: chex.Array jump_cooldown: chex.Array post_jump_cooldown: chex.Array @@ -1028,13 +1030,6 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: is_on_road = ~is_jumping is_landing = jnp.logical_and(state.jump_cooldown == 1, jump_cooldown == 0) - respawn_cooldown = jax.lax.cond( - state.respawn_cooldown > 0, - lambda _: state.respawn_cooldown - 1, - lambda _: jnp.array(0, dtype=jnp.int32), - operand=None, - ) - updated_player_car = self._advance_player_car( position_x=state.player_car.position.x, position_y=state.player_car.position.y, @@ -1058,7 +1053,6 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: round_started_now = jnp.logical_or(state.round_started, speed_action_taken) next_state = state._replace( - respawn_cooldown=respawn_cooldown, jump_cooldown=jump_cooldown, post_jump_cooldown=post_jump_cooldown, is_jumping=is_jumping, @@ -1353,7 +1347,6 @@ def _respawn_after_collision(self, state: UpNDownState, new_lives: chex.Array) - return UpNDownState( score=state.score, lives=new_lives, - respawn_cooldown=jnp.array(self.consts.RESPAWN_HIDE_FRAMES, dtype=jnp.int32), difficulty=state.difficulty, jump_cooldown=jnp.array(0, dtype=jnp.int32), post_jump_cooldown=jnp.array(0, dtype=jnp.int32), @@ -1511,7 +1504,6 @@ def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: state = UpNDownState( score=0, lives=jnp.array(self.consts.INITIAL_LIVES, dtype=jnp.int32), - respawn_cooldown=jnp.array(0, dtype=jnp.int32), difficulty=self.consts.DIFFICULTIES[0], jump_cooldown=0, post_jump_cooldown=0, @@ -1868,214 +1860,209 @@ def combine(i, acc): raster = jax.lax.fori_loop(0, total_segments, combine, raster) - hide_world = state.respawn_cooldown > 0 - - # During respawn hide, only show the road/background to emulate an initial road state. - def render_roads_only(): - return raster - - def render_full_scene(): - def select_enemy_mask(enemy_type: chex.Array, going_left: chex.Array): - left_mask = self.enemy_left_masks[enemy_type] - right_mask = self.enemy_right_masks[enemy_type] - return jax.lax.cond(going_left, lambda _: left_mask, lambda _: right_mask, operand=None) - - def render_enemy(carry, enemy_idx): - raster = carry - enemy_active = state.enemy_cars.active[enemy_idx] - enemy_x = state.enemy_cars.position.x[enemy_idx] - enemy_y = state.enemy_cars.position.y[enemy_idx] - enemy_type = state.enemy_cars.type[enemy_idx] - direction_x = state.enemy_cars.direction_x[enemy_idx] - screen_y = 105 + (enemy_y - state.player_car.position.y) - is_visible = jnp.logical_and(enemy_active, jnp.logical_and(screen_y > 25, screen_y < 195)) - enemy_mask = select_enemy_mask(enemy_type, direction_x < 0) - - raster = jax.lax.cond( - is_visible, - lambda r: self.jr.render_at(r, enemy_x.astype(jnp.int32), screen_y.astype(jnp.int32), enemy_mask), - lambda r: r, - operand=raster, - ) - return raster, None - - raster_enemies, _ = jax.lax.scan(render_enemy, raster, jnp.arange(self.consts.MAX_ENEMY_CARS)) - - jump_offset = jax.lax.cond( - state.is_jumping, - lambda _: self._jump_arc_offset(state.jump_cooldown), - lambda _: jnp.array(0.0, dtype=jnp.float32), - operand=None, + def select_enemy_mask(enemy_type: chex.Array, going_left: chex.Array): + left_mask = self.enemy_left_masks[enemy_type] + right_mask = self.enemy_right_masks[enemy_type] + return jax.lax.cond(going_left, lambda _: left_mask, lambda _: right_mask, operand=None) + + def render_enemy(carry, enemy_idx): + raster = carry + enemy_active = state.enemy_cars.active[enemy_idx] + enemy_x = state.enemy_cars.position.x[enemy_idx] + enemy_y = state.enemy_cars.position.y[enemy_idx] + enemy_type = state.enemy_cars.type[enemy_idx] + direction_x = state.enemy_cars.direction_x[enemy_idx] + screen_y = 105 + (enemy_y - state.player_car.position.y) + is_visible = jnp.logical_and(enemy_active, jnp.logical_and(screen_y > 25, screen_y < 195)) + enemy_mask = select_enemy_mask(enemy_type, direction_x < 0) + + raster = jax.lax.cond( + is_visible, + lambda r: self.jr.render_at(r, enemy_x.astype(jnp.int32), screen_y.astype(jnp.int32), enemy_mask), + lambda r: r, + operand=raster, ) + return raster, None - player_screen_y = jnp.int32(105 - jump_offset) - player_mask = self.SHAPE_MASKS["player"] - raster_player = self.jr.render_at(raster_enemies, state.player_car.position.x, player_screen_y, player_mask) + raster_enemies, _ = jax.lax.scan(render_enemy, raster, jnp.arange(self.consts.MAX_ENEMY_CARS)) - wall_top_mask = self.SHAPE_MASKS["wall_top"] - raster_wall_top = self.jr.render_at(raster_player, 0, 0, wall_top_mask) + jump_offset = jax.lax.cond( + state.is_jumping, + lambda _: self._jump_arc_offset(state.jump_cooldown), + lambda _: jnp.array(0.0, dtype=jnp.float32), + operand=None, + ) - wall_bottom_mask = self.SHAPE_MASKS["wall_bottom"] - raster_wall_bottom = self.jr.render_at(raster_wall_top, 0, 210 - wall_bottom_mask.shape[0], wall_bottom_mask) + player_screen_y = jnp.int32(105 - jump_offset) + player_mask = self.SHAPE_MASKS["player"] + raster_player = self.jr.render_at(raster_enemies, state.player_car.position.x, player_screen_y, player_mask) - all_flags_top_mask = self.SHAPE_MASKS["all_flags_top"] - raster_flags_top = self.jr.render_at(raster_wall_bottom, 10, 20, all_flags_top_mask) + wall_top_mask = self.SHAPE_MASKS["wall_top"] + raster_wall_top = self.jr.render_at(raster_player, 0, 0, wall_top_mask) - return raster_flags_top + wall_bottom_mask = self.SHAPE_MASKS["wall_bottom"] + raster_wall_bottom = self.jr.render_at(raster_wall_top, 0, 210 - wall_bottom_mask.shape[0], wall_bottom_mask) + all_flags_top_mask = self.SHAPE_MASKS["all_flags_top"] + raster_flags_top = self.jr.render_at(raster_wall_bottom, 10, 20, all_flags_top_mask) - def render_rest(raster_input): - # Render score centered at the top using dedicated score digit sprites - score_digits = self.jr.int_to_digits(state.score, max_digits=self.score_max_digits) - non_zero_mask = score_digits != 0 - has_non_zero = jnp.any(non_zero_mask) - first_non_zero = jnp.argmax(non_zero_mask) - start_index = jax.lax.select(has_non_zero, first_non_zero, self.score_max_digits - 1) - num_to_render = jax.lax.select(has_non_zero, self.score_max_digits - start_index, 1) + # Render score centered at the top using dedicated score digit sprites + score_digits = self.jr.int_to_digits(state.score, max_digits=self.score_max_digits) + non_zero_mask = score_digits != 0 + has_non_zero = jnp.any(non_zero_mask) + first_non_zero = jnp.argmax(non_zero_mask) + start_index = jax.lax.select(has_non_zero, first_non_zero, self.score_max_digits - 1) + num_to_render = jax.lax.select(has_non_zero, self.score_max_digits - start_index, 1) - total_width = num_to_render * self.score_digit_spacing - score_x = self.score_center_x - (total_width // 2) + total_width = num_to_render * self.score_digit_spacing + score_x = self.score_center_x - (total_width // 2) - raster_score = self.jr.render_label_selective( - raster_input, - jnp.int32(score_x), - self.score_render_y, - score_digits, - self.score_digit_masks, - start_index, - num_to_render, - spacing=self.score_digit_spacing, - max_digits_to_render=self.score_max_digits, - ) + raster_score = self.jr.render_label_selective( + raster_flags_top, + jnp.int32(score_x), + self.score_render_y, + score_digits, + self.score_digit_masks, + start_index, + num_to_render, + spacing=self.score_digit_spacing, + max_digits_to_render=self.score_max_digits, + ) - # Render flags on the road - flag_pole_mask = self.SHAPE_MASKS["flag_pole"] - - def render_flag(carry, flag_idx): - raster = carry - flag_y = state.flags.y[flag_idx] - flag_road = state.flags.road[flag_idx] - flag_segment = state.flags.road_segment[flag_idx] - flag_collected = state.flags.collected[flag_idx] - flag_color_idx = state.flags.color_idx[flag_idx] - - flag_x = jax.lax.cond( - flag_road == 0, - lambda _: self._get_x_on_road(flag_y, flag_segment, self.consts.FIRST_TRACK_CORNERS_X), - lambda _: self._get_x_on_road(flag_y, flag_segment, self.consts.SECOND_TRACK_CORNERS_X), - operand=None, - ) - screen_y = 105 + (flag_y - state.player_car.position.y) - is_visible = jnp.logical_and( - jnp.logical_and(screen_y > 25, screen_y < 195), - ~flag_collected - ) - color_id = self.flag_palette_ids[flag_color_idx] - colored_flag_mask = jnp.where( - self.flag_solid_mask, - color_id, - self.flag_base_mask, - ) - raster = jax.lax.cond( - is_visible, - lambda r: self.jr.render_at( - self.jr.render_at(r, flag_x.astype(jnp.int32), screen_y.astype(jnp.int32), colored_flag_mask), - (flag_x + 5).astype(jnp.int32), screen_y.astype(jnp.int32), flag_pole_mask - ), - lambda r: r, - operand=raster, - ) - return raster, None - - raster_flags, _ = jax.lax.scan(render_flag, raster_score, jnp.arange(self.consts.NUM_FLAGS)) - - blackout_mask = self.SHAPE_MASKS["blackout_square"] - - def render_blackout(carry, flag_idx): - raster = carry - flag_collected = state.flags_collected_mask[flag_idx] - blackout_x = self.consts.FLAG_TOP_X_POSITIONS[flag_idx] - blackout_y = self.consts.FLAG_TOP_Y - raster = jax.lax.cond( - flag_collected, - lambda r: self.jr.render_at(r, blackout_x, blackout_y, blackout_mask), - lambda r: r, - operand=raster, - ) - return raster, None + # Render flags on the road + flag_pole_mask = self.SHAPE_MASKS["flag_pole"] + + def render_flag(carry, flag_idx): + raster = carry + flag_y = state.flags.y[flag_idx] + flag_road = state.flags.road[flag_idx] + flag_segment = state.flags.road_segment[flag_idx] + flag_collected = state.flags.collected[flag_idx] + flag_color_idx = state.flags.color_idx[flag_idx] - raster_blackout, _ = jax.lax.scan(render_blackout, raster_flags, jnp.arange(self.consts.NUM_FLAGS)) - - def render_collectible(carry, collectible_idx): - raster = carry - collectible_y = state.collectibles.y[collectible_idx] - collectible_x = state.collectibles.x[collectible_idx] - collectible_active = state.collectibles.active[collectible_idx] - collectible_color_idx = state.collectibles.color_idx[collectible_idx] - collectible_type_id = state.collectibles.type_id[collectible_idx] - screen_y = 105 + (collectible_y - state.player_car.position.y) - is_visible = jnp.logical_and( - jnp.logical_and(screen_y > 25, screen_y < 195), - collectible_active - ) + flag_x = jax.lax.cond( + flag_road == 0, + lambda _: self._get_x_on_road(flag_y, flag_segment, self.consts.FIRST_TRACK_CORNERS_X), + lambda _: self._get_x_on_road(flag_y, flag_segment, self.consts.SECOND_TRACK_CORNERS_X), + operand=None, + ) + screen_y = 105 + (flag_y - state.player_car.position.y) + is_visible = jnp.logical_and( + jnp.logical_and(screen_y > 25, screen_y < 195), + ~flag_collected + ) + color_id = self.flag_palette_ids[flag_color_idx] + colored_flag_mask = jnp.where( + self.flag_solid_mask, + color_id, + self.flag_base_mask, + ) + raster = jax.lax.cond( + is_visible, + lambda r: self.jr.render_at( + self.jr.render_at(r, flag_x.astype(jnp.int32), screen_y.astype(jnp.int32), colored_flag_mask), + (flag_x + 5).astype(jnp.int32), screen_y.astype(jnp.int32), flag_pole_mask + ), + lambda r: r, + operand=raster, + ) + return raster, None + + raster_flags, _ = jax.lax.scan(render_flag, raster_score, jnp.arange(self.consts.NUM_FLAGS)) + + blackout_mask = self.SHAPE_MASKS["blackout_square"] + + def render_blackout(carry, flag_idx): + raster = carry + flag_collected = state.flags_collected_mask[flag_idx] + blackout_x = self.consts.FLAG_TOP_X_POSITIONS[flag_idx] + blackout_y = self.consts.FLAG_TOP_Y + raster = jax.lax.cond( + flag_collected, + lambda r: self.jr.render_at(r, blackout_x, blackout_y, blackout_mask), + lambda r: r, + operand=raster, + ) + return raster, None + + raster_blackout, _ = jax.lax.scan(render_blackout, raster_flags, jnp.arange(self.consts.NUM_FLAGS)) + + def render_collectible(carry, collectible_idx): + raster = carry + collectible_y = state.collectibles.y[collectible_idx] + collectible_x = state.collectibles.x[collectible_idx] + collectible_active = state.collectibles.active[collectible_idx] + collectible_color_idx = state.collectibles.color_idx[collectible_idx] + collectible_type_id = state.collectibles.type_id[collectible_idx] + screen_y = 105 + (collectible_y - state.player_car.position.y) + is_visible = jnp.logical_and( + jnp.logical_and(screen_y > 25, screen_y < 195), + collectible_active + ) - def get_sprite_and_mask(type_id): - cherry_result = (self.cherry_base_mask, self.cherry_solid_mask, self.cherry_palette_ids) - balloon_result = (self.balloon_base_mask, self.balloon_solid_mask, self.balloon_palette_ids) - lollypop_result = (self.lollypop_base_mask, self.lollypop_solid_mask, self.lollypop_palette_ids) - ice_cream_result = (self.ice_cream_base_mask, self.ice_cream_solid_mask, self.ice_cream_palette_ids) - return jax.lax.cond( - type_id == self.consts.COLLECTIBLE_TYPE_CHERRY, - lambda _: cherry_result, + def get_sprite_and_mask(type_id): + cherry_result = (self.cherry_base_mask, self.cherry_solid_mask, self.cherry_palette_ids) + balloon_result = (self.balloon_base_mask, self.balloon_solid_mask, self.balloon_palette_ids) + lollypop_result = (self.lollypop_base_mask, self.lollypop_solid_mask, self.lollypop_palette_ids) + ice_cream_result = (self.ice_cream_base_mask, self.ice_cream_solid_mask, self.ice_cream_palette_ids) + return jax.lax.cond( + type_id == self.consts.COLLECTIBLE_TYPE_CHERRY, + lambda _: cherry_result, + lambda _: jax.lax.cond( + type_id == self.consts.COLLECTIBLE_TYPE_BALLOON, + lambda _: balloon_result, lambda _: jax.lax.cond( - type_id == self.consts.COLLECTIBLE_TYPE_BALLOON, - lambda _: balloon_result, - lambda _: jax.lax.cond( - type_id == self.consts.COLLECTIBLE_TYPE_LOLLYPOP, - lambda _: lollypop_result, - lambda _: ice_cream_result, - operand=None, - ), + type_id == self.consts.COLLECTIBLE_TYPE_LOLLYPOP, + lambda _: lollypop_result, + lambda _: ice_cream_result, operand=None, ), operand=None, - ) - - base_mask, solid_mask, palette_ids = get_sprite_and_mask(collectible_type_id) - color_id = palette_ids[collectible_color_idx] - colored_mask = jnp.where( - (base_mask != self.jr.TRANSPARENT_ID) & (base_mask != 0), - color_id, - base_mask, - ) - raster = jax.lax.cond( - is_visible, - lambda r: self.jr.render_at(r, collectible_x.astype(jnp.int32), screen_y.astype(jnp.int32), colored_mask), - lambda r: r, - operand=raster, + ), + operand=None, ) - return raster, None - - raster_collectibles, _ = jax.lax.scan(render_collectible, raster_blackout, jnp.arange(self.consts.MAX_COLLECTIBLES)) - all_lives_bottom_mask = self.SHAPE_MASKS["all_lives_bottom"] - raster_lives = self.jr.render_at(raster_collectibles, 10, 195, all_lives_bottom_mask) - - wall_bottom_mask = self.SHAPE_MASKS["tempPointer"] - raster_pointer = self.jr.render_at(raster_lives, 140, 25, wall_bottom_mask) - - return self.jr.render_from_palette(raster_pointer, self.PALETTE) + base_mask, solid_mask, palette_ids = get_sprite_and_mask(collectible_type_id) + color_id = palette_ids[collectible_color_idx] + colored_mask = jnp.where( + (base_mask != self.jr.TRANSPARENT_ID) & (base_mask != 0), + color_id, + base_mask, + ) + raster = jax.lax.cond( + is_visible, + lambda r: self.jr.render_at(r, collectible_x.astype(jnp.int32), screen_y.astype(jnp.int32), colored_mask), + lambda r: r, + operand=raster, + ) + return raster, None + + raster_collectibles, _ = jax.lax.scan(render_collectible, raster_blackout, jnp.arange(self.consts.MAX_COLLECTIBLES)) + + all_lives_bottom_mask = self.SHAPE_MASKS["all_lives_bottom"] + raster_lives = self.jr.render_at(raster_collectibles, 10, 195, all_lives_bottom_mask) + + # Black out lost lives (similar to flag blackout) + blackout_mask = self.SHAPE_MASKS["blackout_square"] + lives_lost = self.consts.INITIAL_LIVES - state.lives + + def render_life_blackout(carry, life_idx): + raster = carry + # Black out this life if it has been lost (life_idx < lives_lost) + should_blackout = life_idx < lives_lost + blackout_x = self.consts.LIFE_BOTTOM_X_POSITIONS[life_idx] + blackout_y = self.consts.LIFE_BOTTOM_Y + raster = jax.lax.cond( + should_blackout, + lambda r: self.jr.render_at(r, blackout_x, blackout_y, blackout_mask), + lambda r: r, + operand=raster, + ) + return raster, None + + raster_lives_blackout, _ = jax.lax.scan(render_life_blackout, raster_lives, jnp.arange(self.consts.INITIAL_LIVES)) - base_scene = jax.lax.cond( - hide_world, - lambda _: render_roads_only(), - lambda _: render_full_scene(), - operand=None, - ) + wall_bottom_mask = self.SHAPE_MASKS["tempPointer"] + raster_pointer = self.jr.render_at(raster_lives_blackout, 140, 25, wall_bottom_mask) - return jax.lax.cond( - hide_world, - lambda _: self.jr.render_from_palette(base_scene, self.PALETTE), - lambda _: render_rest(base_scene), - operand=None, - ) \ No newline at end of file + return self.jr.render_from_palette(raster_pointer, self.PALETTE) \ No newline at end of file From c8ab54fadcc78e481ff8b795d3ec0dca2786608f Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sat, 20 Dec 2025 22:30:52 +0100 Subject: [PATCH 21/76] cleanup code --- src/jaxatari/games/jax_upndown.py | 288 ++++++++++++++---------------- 1 file changed, 134 insertions(+), 154 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 52e5cd88a..4933e8cbf 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -1,8 +1,8 @@ -from jax._src.pjit import JitWrapped import os import math from functools import partial from typing import NamedTuple, Tuple + import jax import jax.lax import jax.numpy as jnp @@ -14,9 +14,8 @@ from jaxatari.environment import JaxEnvironment, JAXAtariAction as Action class UpNDownConstants(NamedTuple): - FRAME_SKIP: int = 4 + FRAME_SKIP: int = 4 # Used by AtariWrapper DIFFICULTIES: chex.Array = jnp.array([0, 1, 2, 3, 4, 5]) - ACTION_REPEAT_PROBS: float = 0.25 MAX_SPEED: int = 6 INITIAL_LIVES: int = 5 JUMP_ARC_HEIGHT: float = 18.0 @@ -42,8 +41,12 @@ class UpNDownConstants(NamedTuple): LANDING_TOLERANCE: int = 15 # Pixels tolerance for landing on a road (increased by 5 for off-road landings) LATE_JUMP_COLLISION_FRAMES: int = 2 LANDING_COLLISION_DISTANCE: float = 8.0 # Larger collision distance when landing (for crossings) + GROUND_COLLISION_DISTANCE: float = 3.0 # Tight collision distance for ground collisions LATE_JUMP_ENEMY_SCORE: int = 400 STEEP_ROAD_SPEED_REDUCTION_INTERVAL: int = 8 # Frames between each speed reduction on steep roads + PASSIVE_SCORE_INTERVAL: int = 60 # Steps between passive score awards + PASSIVE_SCORE_AMOUNT: int = 10 # Points awarded for passive scoring + COLLISION_THRESHOLD: float = 5.0 # Distance threshold for flag/collectible collision TRACK_LENGTH: int = 1036 FIRST_TRACK_CORNERS_X: chex.Array = jnp.array([30, 75, 128, 75, 21, 75, 131, 111, 150, 95, 150, 115, 150, 108, 150, 115, 115, 75, 18, 38, 67, 38, 38, 20, 64, 30]) TRACK_CORNERS_Y: chex.Array = jnp.array([0, -40, -98, -155, -203, -268, -327, -347, -382, -467, -525, -565, -597, -625, -670, -705, -738, -788, -838, -862, -898, -925, -950, -972, -1000, -1035]) @@ -52,8 +55,6 @@ class UpNDownConstants(NamedTuple): INITIAL_ROAD_POS_Y: int = 25 # Flag constants - 8 flags with different colors matching the top row NUM_FLAGS: int = 8 - FLAG_SIZE: Tuple[int, int] = (11, 6) # height, width of the flag sprite - FLAG_POLE_SIZE: Tuple[int, int] = (7, 2) # height, width of the pole sprite # Flag colors as RGBA values (matching the top row from left to right) FLAG_COLORS: chex.Array = jnp.array([ [184, 50, 50, 255], # Red @@ -73,14 +74,8 @@ class UpNDownConstants(NamedTuple): # Life display constants - positions of life cars at the bottom LIFE_BOTTOM_X_POSITIONS: chex.Array = jnp.array([13, 18, 25, 33, 33]) # X positions for 5 life cars LIFE_BOTTOM_Y: int = 195 - LIFE_BLACKOUT_SIZE: Tuple[int, int] = (14, 14) # Size of blackout square for lives - PICKUP_SCORE: int = 100 # Points awarded for jumping on a pickup truck - FLAG_CARRIER_SCORE: int = 125 # Points awarded for jumping on a flag carrier - CAMARO_SCORE: int = 150 # Points awarded for jumping on a camaro - TRUCK_SCORE: int = 175 # Points awarded for jumping on a truck # Collectible constants - unified dynamic spawning MAX_COLLECTIBLES: int = 2 # Maximum collectibles that can exist at once (pool of mixed types) - COLLECTIBLE_SIZE: Tuple[int, int] = (8, 8) # height, width of collectible sprite COLLECTIBLE_SPAWN_INTERVAL: int = 200 # Steps between spawn attempts COLLECTIBLE_DESPAWN_DISTANCE: int = 500 # Distance beyond which collectibles despawn # Collectible types (indices for type field) @@ -88,8 +83,8 @@ class UpNDownConstants(NamedTuple): COLLECTIBLE_TYPE_BALLOON: int = 1 COLLECTIBLE_TYPE_LOLLYPOP: int = 2 COLLECTIBLE_TYPE_ICE_CREAM: int = 3 - # Collectible type spawn probabilities (must sum to 100) - COLLECTIBLE_SPAWN_PROBABILITIES: chex.Array = jnp.array([35, 30, 25, 10], dtype=jnp.int32) # Cherry: 40%, Balloon: 20%, Lollypop: 20%, IceCream: 20% + # Collectible type spawn probabilities (cumulative thresholds for random sampling) + COLLECTIBLE_SPAWN_PROBABILITIES: chex.Array = jnp.array([35, 65, 90, 100], dtype=jnp.int32) # Cherry: 35%, Balloon: 30%, Lollypop: 25%, IceCream: 10% # Collectible type scores COLLECTIBLE_SCORES: chex.Array = jnp.array([50, 65, 70, 75], dtype=jnp.int32) # [cherry, balloon, lollypop, ice_cream] # Shared collectible colors @@ -201,6 +196,29 @@ def __init__(self, consts: UpNDownConstants = None, reward_funcs: list[callable] Action.DOWNFIRE, ] self.obs_size = 3*4+1+1 + # Speed dividers for movement timing (indexed by speed level) + self._speed_dividers = jnp.array([0, 1, 2, 4, 8, 16, 16]) + + @partial(jax.jit, static_argnums=(0,)) + def _compute_movement_timing(self, speed: chex.Array, step_counter: chex.Array) -> Tuple[chex.Array, chex.Array, chex.Array, chex.Array]: + """Calculate movement timing parameters based on speed. + + Returns: + Tuple of (move_y, move_x, step_size, speed_sign) + """ + abs_speed = jnp.abs(speed) + speed_index = jnp.minimum(abs_speed, jnp.int32(self._speed_dividers.shape[0] - 1)) + speed_divider = self._speed_dividers[speed_index] + effective_divider = jnp.maximum(1, speed_divider) + period = jnp.maximum(1, 16 // effective_divider) + half_period = jnp.maximum(1, period // 2) + speed_sign = jnp.sign(speed).astype(jnp.float32) + + move_y = jnp.logical_and((step_counter % period) == (half_period % period), speed != 0) + move_x = jnp.logical_and((step_counter % period) == 0, speed != 0) + step_size = jnp.where(speed_index >= self.consts.MAX_SPEED, 1.5, 1.0) + + return move_y, move_x, step_size, speed_sign @partial(jax.jit, static_argnums=(0,)) def _get_slope_and_intercept_from_indices(self, current_road: chex.Array, road_index_A: chex.Array, road_index_B: chex.Array) -> Tuple[chex.Array, chex.Array]: @@ -219,14 +237,6 @@ def _get_slope_and_intercept_from_indices(self, current_road: chex.Array, road_i b = tracky[road_index] - slope * trackx[road_index] return slope, b - @partial(jax.jit, static_argnums=(0,)) - def _getSlopeAndB(self, state: UpNDownState) -> chex.Array: - return self._get_slope_and_intercept_from_indices( - state.player_car.current_road, - state.player_car.road_index_A, - state.player_car.road_index_B, - ) - @partial(jax.jit, static_argnums=(0,)) def _is_on_line_for_position(self, position: EntityPosition, slope: chex.Array, b: chex.Array, player_speed: chex.Array, turn: chex.Array) -> chex.Array: x_step = abs(jnp.subtract(position.y, slope * (position.x) + b)) @@ -279,30 +289,46 @@ def _is_steep_road_segment(self, current_road: chex.Array, road_index_A: chex.Ar ) # A segment is steep if there's no X change (or very small change) return x_diff < 1.0 - - @partial(jax.jit, static_argnums=(0,)) - def _isOnLine(self, state: UpNDownState, player_speed: chex.Array, turn: chex.Array) -> chex.Array: - slope, b = self._getSlopeAndB(state) - return self._is_on_line_for_position(state.player_car.position, slope, b, player_speed, turn) - - @partial(jax.jit, static_argnums=(0,)) - def _landing_in_water(self, state: UpNDownState, new_position_x: chex.Array, new_position_y: chex.Array) -> chex.Array: - road_A_x = ((new_position_y - self.consts.TRACK_CORNERS_Y[state.player_car.road_index_A]) / (self.consts.TRACK_CORNERS_Y[state.player_car.road_index_A+1] - self.consts.TRACK_CORNERS_Y[state.player_car.road_index_A])) * (self.consts.FIRST_TRACK_CORNERS_X[state.player_car.road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_X[state.player_car.road_index_A]) + self.consts.FIRST_TRACK_CORNERS_X[state.player_car.road_index_A] - road_B_x = ((new_position_y - self.consts.TRACK_CORNERS_Y[state.player_car.road_index_B]) / (self.consts.TRACK_CORNERS_Y[state.player_car.road_index_B+1] - self.consts.TRACK_CORNERS_Y[state.player_car.road_index_B])) * (self.consts.SECOND_TRACK_CORNERS_X[state.player_car.road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[state.player_car.road_index_B]) + self.consts.SECOND_TRACK_CORNERS_X[state.player_car.road_index_B] - distance_to_road_A = jnp.abs(new_position_x - road_A_x) - distance_to_road_B = jnp.abs(new_position_x - road_B_x) - landing_in_Water = jnp.logical_and(distance_to_road_A > self.consts.LANDING_TOLERANCE, distance_to_road_B > self.consts.LANDING_TOLERANCE) - between_roads = jnp.logical_and(new_position_x > jnp.minimum(road_A_x, road_B_x), new_position_x < jnp.maximum(road_A_x, road_B_x)) - return landing_in_Water, between_roads, road_A_x, road_B_x @partial(jax.jit, static_argnums=(0,)) - def _landing_in_water_for_indices(self, road_index_A: chex.Array, road_index_B: chex.Array, new_position_x: chex.Array, new_position_y: chex.Array) -> chex.Array: - road_A_x = ((new_position_y - self.consts.TRACK_CORNERS_Y[road_index_A]) / (self.consts.TRACK_CORNERS_Y[road_index_A+1] - self.consts.TRACK_CORNERS_Y[road_index_A])) * (self.consts.FIRST_TRACK_CORNERS_X[road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_X[road_index_A]) + self.consts.FIRST_TRACK_CORNERS_X[road_index_A] - road_B_x = ((new_position_y - self.consts.TRACK_CORNERS_Y[road_index_B]) / (self.consts.TRACK_CORNERS_Y[road_index_B+1] - self.consts.TRACK_CORNERS_Y[road_index_B])) * (self.consts.SECOND_TRACK_CORNERS_X[road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[road_index_B]) + self.consts.SECOND_TRACK_CORNERS_X[road_index_B] + def _check_landing_position( + self, + road_index_A: chex.Array, + road_index_B: chex.Array, + new_position_x: chex.Array, + new_position_y: chex.Array, + ) -> Tuple[chex.Array, chex.Array, chex.Array, chex.Array]: + """Check if a position is valid for landing (on or between roads). + + Returns: + Tuple of (landing_in_water, between_roads, road_A_x, road_B_x) + """ + # Calculate X position on road A at the given Y + y_ratio_A = (new_position_y - self.consts.TRACK_CORNERS_Y[road_index_A]) / ( + self.consts.TRACK_CORNERS_Y[road_index_A + 1] - self.consts.TRACK_CORNERS_Y[road_index_A] + ) + road_A_x = y_ratio_A * ( + self.consts.FIRST_TRACK_CORNERS_X[road_index_A + 1] - self.consts.FIRST_TRACK_CORNERS_X[road_index_A] + ) + self.consts.FIRST_TRACK_CORNERS_X[road_index_A] + + # Calculate X position on road B at the given Y + y_ratio_B = (new_position_y - self.consts.TRACK_CORNERS_Y[road_index_B]) / ( + self.consts.TRACK_CORNERS_Y[road_index_B + 1] - self.consts.TRACK_CORNERS_Y[road_index_B] + ) + road_B_x = y_ratio_B * ( + self.consts.SECOND_TRACK_CORNERS_X[road_index_B + 1] - self.consts.SECOND_TRACK_CORNERS_X[road_index_B] + ) + self.consts.SECOND_TRACK_CORNERS_X[road_index_B] + distance_to_road_A = jnp.abs(new_position_x - road_A_x) distance_to_road_B = jnp.abs(new_position_x - road_B_x) - landing_in_water = jnp.logical_and(distance_to_road_A > self.consts.LANDING_TOLERANCE, distance_to_road_B > self.consts.LANDING_TOLERANCE) - between_roads = jnp.logical_and(new_position_x > jnp.minimum(road_A_x, road_B_x), new_position_x < jnp.maximum(road_A_x, road_B_x)) + landing_in_water = jnp.logical_and( + distance_to_road_A > self.consts.LANDING_TOLERANCE, + distance_to_road_B > self.consts.LANDING_TOLERANCE, + ) + between_roads = jnp.logical_and( + new_position_x > jnp.minimum(road_A_x, road_B_x), + new_position_x < jnp.maximum(road_A_x, road_B_x), + ) return landing_in_water, between_roads, road_A_x, road_B_x @partial(jax.jit, static_argnums=(0,)) @@ -315,7 +341,6 @@ def _advance_player_car( current_road: chex.Array, speed: chex.Array, is_jumping: chex.Array, - is_on_road: chex.Array, step_counter: chex.Array, width: chex.Array, height: chex.Array, @@ -333,15 +358,8 @@ def _advance_player_car( - If between roads: snap to nearest road - If too far from both roads (outside the road area): crash (water) """ - # Speed-based movement timing - dividers = jnp.array([0, 1, 2, 4, 8, 16, 16]) - abs_speed = jnp.abs(speed) - speed_index = jnp.minimum(abs_speed, jnp.int32(dividers.shape[0] - 1)) - speed_divider = dividers[speed_index] - effective_divider = jnp.maximum(1, speed_divider) - period = jnp.maximum(1, 16 // effective_divider) - half_period = jnp.maximum(1, period // 2) - speed_sign = jax.lax.cond(speed != 0, lambda _: jax.lax.abs(speed) / speed, lambda _: jnp.array(0.0), operand=None) + # Calculate movement timing using helper + move_y, move_x, step_size, speed_sign = self._compute_movement_timing(speed, step_counter) # Get slope and intercept for current road slope, b = self._get_slope_and_intercept_from_indices(current_road, road_index_A, road_index_B) @@ -353,14 +371,8 @@ def _advance_player_car( lambda _: self.consts.SECOND_TRACK_CORNERS_X[road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[road_index_B], operand=None, ) - car_direction_x = jax.lax.cond(direction_raw > 0, lambda _: 1, lambda _: -1, operand=None) - - # Movement timing flags - move_y = jnp.logical_and((step_counter % period) == (half_period % period), speed != 0) - move_x = jnp.logical_and((step_counter % period) == 0, speed != 0) - - # Step size (slightly larger at max speed) - step_size = jnp.where(speed_index >= self.consts.MAX_SPEED, 1.5, 1.0) + # Use sign, default to -1 for zero (vertical segments) + car_direction_x = jnp.where(direction_raw == 0, -1, jnp.sign(direction_raw)).astype(jnp.int32) position = EntityPosition(x=position_x, y=position_y, width=width, height=height) @@ -532,14 +544,8 @@ def _advance_car_core( car_type: chex.Array, ) -> Car: """Simplified car advancement for enemy cars (no jumping/landing logic).""" - dividers = jnp.array([0, 1, 2, 4, 8, 16, 16]) - abs_speed = jnp.abs(speed) - speed_index = jnp.minimum(abs_speed, jnp.int32(dividers.shape[0] - 1)) - speed_divider = dividers[speed_index] - effective_divider = jnp.maximum(1, speed_divider) - period = jnp.maximum(1, 16 // effective_divider) - half_period = jnp.maximum(1, period // 2) - speed_sign = jax.lax.cond(speed != 0, lambda _: jax.lax.abs(speed) / speed, lambda _: jnp.array(0.0), operand=None) + # Calculate movement timing using helper + move_y, move_x, step_size, speed_sign = self._compute_movement_timing(speed, step_counter) slope, b = self._get_slope_and_intercept_from_indices(current_road, road_index_A, road_index_B) @@ -549,12 +555,8 @@ def _advance_car_core( lambda _: self.consts.SECOND_TRACK_CORNERS_X[road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[road_index_B], operand=None, ) - car_direction_x = jax.lax.cond(direction_raw > 0, lambda _: 1, lambda _: -1, operand=None) - - move_y = jnp.logical_and((step_counter % period) == (half_period % period), speed != 0) - move_x = jnp.logical_and((step_counter % period) == 0, speed != 0) - - step_size = jnp.where(speed_index >= self.consts.MAX_SPEED, 1.5, 1.0) + # Use sign, default to -1 for zero (vertical segments) + car_direction_x = jnp.where(direction_raw == 0, -1, jnp.sign(direction_raw)).astype(jnp.int32) position = EntityPosition(x=position_x, y=position_y, width=width, height=height) @@ -653,7 +655,7 @@ def check_flag_collision(flag_idx): ) collision = jnp.logical_and( - jnp.logical_and(y_distance < 5, x_distance < 5), #change the distance threshold if needed + jnp.logical_and(y_distance < self.consts.COLLISION_THRESHOLD, x_distance < self.consts.COLLISION_THRESHOLD), jnp.logical_and(same_road, ~flag_collected) ) return collision @@ -720,34 +722,19 @@ def find_inactive_idx(collectibles_in): # Generate random spawn position using fold_in for deterministic randomness base_key = jax.random.PRNGKey(0) key_for_spawn = jax.random.fold_in(base_key, state.step_counter) - key1, key2, key3, key4, key5 = jax.random.split(key_for_spawn, 5) + key1, key2, key3, key4 = jax.random.split(key_for_spawn, 4) y_spawn = jax.random.uniform(key1, minval=-900.0, maxval=-100.0) road_spawn = jnp.array(jax.random.randint(key2, shape=(), minval=0, maxval=2), dtype=jnp.int32) color_spawn = jnp.array(jax.random.randint(key3, shape=(), minval=0, maxval=len(self.consts.COLLECTIBLE_COLORS)), dtype=jnp.int32) - # Randomly select collectible type based on spawn probabilities - # Convert probabilities (%) to cumulative distribution for sampling + # Randomly select collectible type using cumulative probability thresholds + # COLLECTIBLE_SPAWN_PROBABILITIES contains cumulative values: [35, 65, 90, 100] + # Cherry: [0-35), Balloon: [35-65), Lollypop: [65-90), IceCream: [90-100] rand_type = jax.random.uniform(key4, minval=0.0, maxval=100.0) - # Use cumulative probabilities: cherry [0-40], balloon [40-60], lollypop [60-80], ice_cream [80-100] - def select_type(rand_val): - # Returns 0=cherry, 1=balloon, 2=lollypop, 3=ice_cream - type_id = jnp.where( - rand_val < self.consts.COLLECTIBLE_SPAWN_PROBABILITIES[0], - jnp.int32(self.consts.COLLECTIBLE_TYPE_CHERRY), - jnp.where( - rand_val < self.consts.COLLECTIBLE_SPAWN_PROBABILITIES[1], - jnp.int32(self.consts.COLLECTIBLE_TYPE_BALLOON), - jnp.where( - rand_val < self.consts.COLLECTIBLE_SPAWN_PROBABILITIES[2], - jnp.int32(self.consts.COLLECTIBLE_TYPE_LOLLYPOP), - jnp.int32(self.consts.COLLECTIBLE_TYPE_ICE_CREAM) - ) - ) - ) - return type_id - - type_id_spawn = select_type(rand_type) + # Use searchsorted for efficient threshold lookup + type_id_spawn = jnp.searchsorted(self.consts.COLLECTIBLE_SPAWN_PROBABILITIES, rand_type, side='right') + type_id_spawn = jnp.clip(type_id_spawn, 0, 3).astype(jnp.int32) # Calculate X position on road segment_spawn = self._get_road_segment(y_spawn) @@ -761,34 +748,32 @@ def select_type(rand_val): # Create mask for which collectibles to update update_mask = (jnp.arange(self.consts.MAX_COLLECTIBLES) == spawn_idx) & should_spawn & has_inactive_slot - # Update collectibles with proper masking - updated_collectibles = Collectible( - y=jnp.where(update_mask, y_spawn, state.collectibles.y), - x=jnp.where(update_mask, x_spawn, state.collectibles.x), - road=jnp.where(update_mask, road_spawn, state.collectibles.road), - color_idx=jnp.where(update_mask, color_spawn, state.collectibles.color_idx), - type_id=jnp.where(update_mask, type_id_spawn, state.collectibles.type_id), - active=jnp.where(update_mask, True, state.collectibles.active), - ) + # Update collectibles with proper masking - spawn new items + spawned_y = jnp.where(update_mask, y_spawn, state.collectibles.y) + spawned_x = jnp.where(update_mask, x_spawn, state.collectibles.x) + spawned_road = jnp.where(update_mask, road_spawn, state.collectibles.road) + spawned_color_idx = jnp.where(update_mask, color_spawn, state.collectibles.color_idx) + spawned_type_id = jnp.where(update_mask, type_id_spawn, state.collectibles.type_id) + spawned_active = jnp.where(update_mask, True, state.collectibles.active) # Despawn logic - remove collectibles too far from player def check_despawn(idx): - c_y = updated_collectibles.y[idx] - c_active = updated_collectibles.active[idx] + c_y = spawned_y[idx] + c_active = spawned_active[idx] distance = jnp.abs(new_player_y - c_y) too_far = distance > self.consts.COLLECTIBLE_DESPAWN_DISTANCE should_despawn = jnp.logical_and(c_active, too_far) return should_despawn despawn_mask = jax.vmap(check_despawn)(jnp.arange(self.consts.MAX_COLLECTIBLES)) - new_active = jnp.logical_and(updated_collectibles.active, ~despawn_mask) + active_after_despawn = jnp.logical_and(spawned_active, ~despawn_mask) # Collision detection def check_collision(idx): - c_y = updated_collectibles.y[idx] - c_x = updated_collectibles.x[idx] - c_road = updated_collectibles.road[idx] - c_active = updated_collectibles.active[idx] + c_y = spawned_y[idx] + c_x = spawned_x[idx] + c_road = spawned_road[idx] + c_active = spawned_active[idx] y_distance = jnp.abs(new_player_y - c_y) x_distance = jnp.abs(player_x - c_x) @@ -798,7 +783,7 @@ def check_collision(idx): ) collision = jnp.logical_and( - jnp.logical_and(y_distance < 5, x_distance < 5), + jnp.logical_and(y_distance < self.consts.COLLISION_THRESHOLD, x_distance < self.consts.COLLISION_THRESHOLD), jnp.logical_and(same_road, c_active) ) return collision @@ -806,12 +791,12 @@ def check_collision(idx): collections = jax.vmap(check_collision)(jnp.arange(self.consts.MAX_COLLECTIBLES)) # Deactivate collected items - new_active = jnp.logical_and(new_active, ~collections) + final_active = jnp.logical_and(active_after_despawn, ~collections) # Update score - use type_id to look up score value def get_collection_score(idx): is_collected = collections[idx] - type_id = updated_collectibles.type_id[idx] + type_id = spawned_type_id[idx] # Look up score based on type_id using array indexing score = self.consts.COLLECTIBLE_SCORES[type_id] return jnp.where(is_collected, score, 0) @@ -819,13 +804,14 @@ def get_collection_score(idx): score_array = jax.vmap(get_collection_score)(jnp.arange(self.consts.MAX_COLLECTIBLES)) score_delta = jnp.sum(score_array) + # Create final collectibles state updated_collectibles = Collectible( - y=updated_collectibles.y, - x=updated_collectibles.x, - road=updated_collectibles.road, - color_idx=updated_collectibles.color_idx, - type_id=updated_collectibles.type_id, - active=new_active, + y=spawned_y, + x=spawned_x, + road=spawned_road, + color_idx=spawned_color_idx, + type_id=spawned_type_id, + active=final_active, ) return updated_collectibles, score_delta, new_collectible_timer @@ -1038,7 +1024,6 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: current_road=state.player_car.current_road, speed=player_speed, is_jumping=is_jumping, - is_on_road=is_on_road, step_counter=state.step_counter, width=state.player_car.position.width, height=state.player_car.position.height, @@ -1383,10 +1368,9 @@ def _enemy_collision_step_main(self, state: UpNDownState) -> UpNDownState: dy = jnp.abs(state.enemy_cars.position.y - player_y) wrapped_dy = jnp.minimum(dy, self.consts.TRACK_LENGTH - dy) - # For ground collision: only trigger when enemy position is within 3 pixels - ground_collision_distance = 3.0 - overlap_x_ground = dx <= ground_collision_distance - overlap_y_ground = wrapped_dy <= ground_collision_distance + # For ground collision: only trigger when enemy position is within tight distance + overlap_x_ground = dx <= self.consts.GROUND_COLLISION_DISTANCE + overlap_y_ground = wrapped_dy <= self.consts.GROUND_COLLISION_DISTANCE # For landing collision: use larger distance and road-independent (for crossings) overlap_x_landing = dx <= self.consts.LANDING_COLLISION_DISTANCE overlap_y_landing = wrapped_dy <= self.consts.LANDING_COLLISION_DISTANCE @@ -1454,10 +1438,13 @@ def handle_ground_collision(): ) def _passive_score_step_main(self, state: UpNDownState) -> UpNDownState: - """Award passive score every 60 steps after the player has started moving.""" + """Award passive score at regular intervals after the player has started moving.""" bonus = jax.lax.cond( - jnp.logical_and(state.round_started, state.movement_steps % 60 == 0), - lambda _: jnp.int32(10), + jnp.logical_and( + state.round_started, + state.movement_steps % self.consts.PASSIVE_SCORE_INTERVAL == 0, + ), + lambda _: jnp.int32(self.consts.PASSIVE_SCORE_AMOUNT), lambda _: jnp.int32(0), operand=None, ) @@ -1695,22 +1682,21 @@ def _pad_mask(mask): self.flag_solid_mask = self.flag_base_mask != self.jr.TRANSPARENT_ID self.flag_palette_ids = self._compute_flag_palette_ids() - # Precompute collectible mask data for recoloring (unified for all types: cherry, balloon, lollypop, ice cream) + # Precompute collectible mask data for recoloring (unified for all types) + # Reuse the same palette IDs since all collectibles use FLAG_COLORS + self.collectible_palette_ids = self.flag_palette_ids + self.cherry_base_mask = self.SHAPE_MASKS["cherry"] self.cherry_solid_mask = self.cherry_base_mask != self.jr.TRANSPARENT_ID - self.cherry_palette_ids = self._compute_flag_palette_ids() self.balloon_base_mask = self.SHAPE_MASKS["balloon"] self.balloon_solid_mask = self.balloon_base_mask != self.jr.TRANSPARENT_ID - self.balloon_palette_ids = self._compute_flag_palette_ids() self.lollypop_base_mask = self.SHAPE_MASKS["lollypop"] self.lollypop_solid_mask = self.lollypop_base_mask != self.jr.TRANSPARENT_ID - self.lollypop_palette_ids = self._compute_flag_palette_ids() self.ice_cream_base_mask = self.SHAPE_MASKS["ice_cream"] self.ice_cream_solid_mask = self.ice_cream_base_mask != self.jr.TRANSPARENT_ID - self.ice_cream_palette_ids = self._compute_flag_palette_ids() # Score rendering helpers self.score_digit_masks = self.SHAPE_MASKS["score_digits"] @@ -1736,7 +1722,6 @@ def _get_road_sprite_sizes(self, road_files: list[str]) -> list: sprite = jnp.load(f"{road_dir}/{sprite_name}") sizes.append(sprite.shape[0]) complete_size = int(sum(sizes)) - jax.debug.print("Complete road size: {}", complete_size) return sizes, complete_size def _find_palette_id(self, rgba: jnp.ndarray) -> int: @@ -2001,25 +1986,20 @@ def render_collectible(carry, collectible_idx): ) def get_sprite_and_mask(type_id): - cherry_result = (self.cherry_base_mask, self.cherry_solid_mask, self.cherry_palette_ids) - balloon_result = (self.balloon_base_mask, self.balloon_solid_mask, self.balloon_palette_ids) - lollypop_result = (self.lollypop_base_mask, self.lollypop_solid_mask, self.lollypop_palette_ids) - ice_cream_result = (self.ice_cream_base_mask, self.ice_cream_solid_mask, self.ice_cream_palette_ids) - return jax.lax.cond( - type_id == self.consts.COLLECTIBLE_TYPE_CHERRY, - lambda _: cherry_result, - lambda _: jax.lax.cond( - type_id == self.consts.COLLECTIBLE_TYPE_BALLOON, - lambda _: balloon_result, - lambda _: jax.lax.cond( - type_id == self.consts.COLLECTIBLE_TYPE_LOLLYPOP, - lambda _: lollypop_result, - lambda _: ice_cream_result, - operand=None, - ), - operand=None, - ), - operand=None, + # Use switch for O(1) lookup instead of nested conditionals + def get_cherry(_): + return (self.cherry_base_mask, self.cherry_solid_mask, self.collectible_palette_ids) + def get_balloon(_): + return (self.balloon_base_mask, self.balloon_solid_mask, self.collectible_palette_ids) + def get_lollypop(_): + return (self.lollypop_base_mask, self.lollypop_solid_mask, self.collectible_palette_ids) + def get_ice_cream(_): + return (self.ice_cream_base_mask, self.ice_cream_solid_mask, self.collectible_palette_ids) + + return jax.lax.switch( + type_id, + [get_cherry, get_balloon, get_lollypop, get_ice_cream], + None, ) base_mask, solid_mask, palette_ids = get_sprite_and_mask(collectible_type_id) From 735e152537b60cfb4c94b49d7cd5d857760da03d Mon Sep 17 00:00:00 2001 From: shaik05 Date: Sun, 21 Dec 2025 11:06:49 +0100 Subject: [PATCH 22/76] modified respawn --- src/jaxatari/games/jax_upndown.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 4933e8cbf..311b0b47d 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -19,6 +19,10 @@ class UpNDownConstants(NamedTuple): MAX_SPEED: int = 6 INITIAL_LIVES: int = 5 JUMP_ARC_HEIGHT: float = 18.0 + RESPAWN_DELAY_FRAMES: int = 60 + RESPAWN_Y: int = 0 + RESPAWN_X: int = 30 + ALL_FLAGS_BONUS: int = 1000 # Enemy spawning and movement MAX_ENEMY_CARS: int = 8 ENEMY_SPAWN_INTERVAL: int = 80 @@ -144,7 +148,6 @@ class EnemyCars(NamedTuple): class UpNDownState(NamedTuple): score: chex.Array - lives: chex.Array difficulty: chex.Array jump_cooldown: chex.Array post_jump_cooldown: chex.Array @@ -1332,6 +1335,8 @@ def _respawn_after_collision(self, state: UpNDownState, new_lives: chex.Array) - return UpNDownState( score=state.score, lives=new_lives, + is_dead=jnp.array(False), + respawn_timer=jnp.array(0, dtype=jnp.int32), difficulty=state.difficulty, jump_cooldown=jnp.array(0, dtype=jnp.int32), post_jump_cooldown=jnp.array(0, dtype=jnp.int32), @@ -1491,6 +1496,8 @@ def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: state = UpNDownState( score=0, lives=jnp.array(self.consts.INITIAL_LIVES, dtype=jnp.int32), + is_dead=jnp.array(False), + respawn_timer=jnp.array(0, dtype=jnp.int32), difficulty=self.consts.DIFFICULTIES[0], jump_cooldown=0, post_jump_cooldown=0, @@ -2007,6 +2014,7 @@ def get_ice_cream(_): colored_mask = jnp.where( (base_mask != self.jr.TRANSPARENT_ID) & (base_mask != 0), color_id, + base_mask, ) raster = jax.lax.cond( From a77c31c2711e76204d1ad5f7ba91236a9509f7ad Mon Sep 17 00:00:00 2001 From: shaik05 Date: Fri, 6 Mar 2026 13:15:02 +0100 Subject: [PATCH 23/76] Allow backward jumping and remove steep road mechanics in UpNDown --- src/jaxatari/games/jax_upndown.py | 98 ++++++++++++++++++++++++++++--- 1 file changed, 91 insertions(+), 7 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 311b0b47d..5d988acf5 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -47,7 +47,11 @@ class UpNDownConstants(NamedTuple): LANDING_COLLISION_DISTANCE: float = 8.0 # Larger collision distance when landing (for crossings) GROUND_COLLISION_DISTANCE: float = 3.0 # Tight collision distance for ground collisions LATE_JUMP_ENEMY_SCORE: int = 400 - STEEP_ROAD_SPEED_REDUCTION_INTERVAL: int = 8 # Frames between each speed reduction on steep roads + STEEP_ROAD_SPEED_REDUCTION_INTERVAL: int = 12 # Frames between each speed reduction on steep roads + STEEP_ROAD_MIN_SPEED: float = -2.0 # Minimum speed on steep roads + STEEP_ROAD_JUMP_BOOST: float = 1.5 # Multiplier for jump height on steep roads + STEEP_ROAD_RECOVERY_BOOST: float = 0.8 # Speed boost after leaving steep road + STEEP_ROAD_COOLDOWN: int = 5 PASSIVE_SCORE_INTERVAL: int = 60 # Steps between passive score awards PASSIVE_SCORE_AMOUNT: int = 10 # Points awarded for passive scoring COLLISION_THRESHOLD: float = 5.0 # Distance threshold for flag/collectible collision @@ -209,8 +213,8 @@ def _compute_movement_timing(self, speed: chex.Array, step_counter: chex.Array) Returns: Tuple of (move_y, move_x, step_size, speed_sign) """ - abs_speed = jnp.abs(speed) - speed_index = jnp.minimum(abs_speed, jnp.int32(self._speed_dividers.shape[0] - 1)) + abs_speed = jnp.abs(speed).astype(jnp.int32) + speed_index = jnp.minimum(abs_speed, self._speed_dividers.shape[0] - 1).astype(jnp.int32) speed_divider = self._speed_dividers[speed_index] effective_divider = jnp.maximum(1, speed_divider) period = jnp.maximum(1, 16 // effective_divider) @@ -222,6 +226,69 @@ def _compute_movement_timing(self, speed: chex.Array, step_counter: chex.Array) step_size = jnp.where(speed_index >= self.consts.MAX_SPEED, 1.5, 1.0) return move_y, move_x, step_size, speed_sign + def _apply_steep_road_penalty( + self, + speed: chex.Array, + is_on_steep_road: chex.Array, + steep_road_timer: chex.Array, + is_jumping: chex.Array, + jump_cooldown: chex.Array, + ) -> Tuple[chex.Array, chex.Array, chex.Array]: + """ + Apply enhanced steep road penalty with perfect balance and edge case handling. + + - Dynamically reduces speed on steep roads when going upward. + - Provides jump boost and recovery for better flow. + - Includes cooldown to prevent rapid reductions. + + Returns: (new_speed, new_timer, jump_boost_multiplier) + """ + going_up = speed > 0 + on_steep_going_up = jnp.logical_and(is_on_steep_road, going_up) + in_cooldown = steep_road_timer < 0 # Negative timer indicates cooldown + + # Increment timer only if not in cooldown and on steep road going up + timer_increment = jax.lax.cond( + jnp.logical_and(on_steep_going_up, jnp.logical_not(in_cooldown)), + lambda _: 1, + lambda _: 0, + operand=None, + ) + new_timer = steep_road_timer + timer_increment + + # Apply reduction when timer reaches interval and not in cooldown + should_reduce = jnp.logical_and( + on_steep_going_up, + jnp.logical_and(new_timer >= self.consts.STEEP_ROAD_SPEED_REDUCTION_INTERVAL, jnp.logical_not(in_cooldown)) + ) + + # Proportional reduction: stronger for higher speeds, with minimum cap + reduction_factor = jnp.maximum(0.05, speed * 0.15) # 5-15% of speed + reduced_speed = jnp.maximum(speed - reduction_factor, self.consts.STEEP_ROAD_MIN_SPEED) + + # Set cooldown after reduction (negative timer) + final_timer = jax.lax.cond( + should_reduce, + lambda _: -self.consts.STEEP_ROAD_COOLDOWN, + lambda _: new_timer, + operand=None, + ) + + # Recovery boost after leaving steep road (not jumping) + just_left_steep = jnp.logical_and(jnp.logical_not(on_steep_going_up), jnp.logical_not(is_jumping)) + recovery_boost = jax.lax.cond(just_left_steep, lambda _: self.consts.STEEP_ROAD_RECOVERY_BOOST, lambda _: 0.0, operand=None) + + # Jump boost if jumping on steep road + jump_boost = jax.lax.cond( + jnp.logical_and(on_steep_going_up, jump_cooldown > 0), + lambda _: self.consts.STEEP_ROAD_JUMP_BOOST, + lambda _: 1.0, + operand=None, + ) + + final_speed = jax.lax.cond(should_reduce, lambda _: reduced_speed + recovery_boost, lambda _: speed + recovery_boost, operand=None) + + return final_speed, final_timer, jump_boost @partial(jax.jit, static_argnums=(0,)) def _get_slope_and_intercept_from_indices(self, current_road: chex.Array, road_index_A: chex.Array, road_index_B: chex.Array) -> Tuple[chex.Array, chex.Array]: @@ -854,7 +921,7 @@ def _death_step(self, state: UpNDownState) -> UpNDownState: x=jnp.array(self.consts.RESPAWN_X, dtype=jnp.float32), y=jnp.array(self.consts.RESPAWN_Y, dtype=jnp.float32), ), - speed=0, + speed=jnp.array(0.0, dtype=jnp.float32), current_road=0, ), lambda _: state.player_car, @@ -872,7 +939,7 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: up = jnp.logical_or(action == Action.UP, action == Action.UPFIRE) down = jnp.logical_or(action == Action.DOWN, action == Action.DOWNFIRE) jump = jnp.logical_or(action == Action.FIRE, jnp.logical_or(action == Action.UPFIRE, action == Action.DOWNFIRE)) - player_speed = state.player_car.speed + player_speed = state.player_car.speed.astype(jnp.float32) player_speed = jax.lax.cond( jnp.logical_and(state.player_car.speed < self.consts.MAX_SPEED, up), @@ -899,6 +966,9 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: # 1. Player is on a steep road section # 2. Player is not jumping # 3. Player has positive speed (going upward) + player_speed, steep_road_timer, jump_boost_multiplier = self._apply_steep_road_penalty( + player_speed, is_on_steep_road, state.steep_road_timer, state.is_jumping, state.jump_cooldown + ) on_steep_going_up = jnp.logical_and( is_on_steep_road, jnp.logical_and( @@ -936,7 +1006,7 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: can_start_jump = jnp.logical_and(state.jump_cooldown == 0, state.post_jump_cooldown == 0) is_jumping = jnp.logical_or( jnp.logical_and(state.is_jumping, state.jump_cooldown > 0), - jnp.logical_and(state.is_on_road, jnp.logical_and(player_speed >= 0, jnp.logical_and(can_start_jump, jump))), + jnp.logical_and(state.is_on_road,jnp.logical_and(can_start_jump, jump)), ) # Detect when a new jump is starting (was not jumping, now is jumping) @@ -1018,6 +1088,7 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: ) is_on_road = ~is_jumping is_landing = jnp.logical_and(state.jump_cooldown == 1, jump_cooldown == 0) + jump_arc_height = self.consts.JUMP_ARC_HEIGHT * jump_boost_multiplier updated_player_car = self._advance_player_car( position_x=state.player_car.position.x, @@ -1324,7 +1395,7 @@ def _respawn_after_collision(self, state: UpNDownState, new_lives: chex.Array) - width=self.consts.PLAYER_SIZE[0], height=self.consts.PLAYER_SIZE[1], ), - speed=jnp.array(0, dtype=jnp.int32), + speed=jnp.array(0.0, dtype=jnp.float32), direction_x=jnp.array(0, dtype=jnp.int32), current_road=respawn_road, road_index_A=start_segment, @@ -1625,6 +1696,19 @@ def __init__(self, consts: UpNDownConstants = None): channels=3, #downscale=(84, 84) ) + def _createBackgroundSprite(self, dimensions: Tuple[int, int]) -> jnp.ndarray: + + height, width = dimensions + # Create a vertical gradient: blue at top, lighter blue at bottom + top_color = jnp.array([135, 206, 235, 255], dtype=jnp.uint8) # Sky blue + bottom_color = jnp.array([173, 216, 230, 255], dtype=jnp.uint8) # Lighter sky blue + + # Linear interpolation for gradient + y_coords = jnp.arange(height, dtype=jnp.float32) / (height - 1) + gradient = jnp.outer(y_coords, bottom_color - top_color) + top_color + gradient = jnp.clip(gradient, 0, 255).astype(jnp.uint8) + + return gradient self.jr = render_utils.JaxRenderingUtils(self.config) background = self._createBackgroundSprite(self.config.game_dimensions) From 369b19f4529d6d11698d163460bb2fd9a86f638e Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sun, 21 Dec 2025 15:38:54 +0100 Subject: [PATCH 24/76] improve code quality --- src/jaxatari/games/jax_upndown.py | 346 +++++++++++++----------------- 1 file changed, 144 insertions(+), 202 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 5d988acf5..7735a0987 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -1,5 +1,4 @@ import os -import math from functools import partial from typing import NamedTuple, Tuple @@ -14,7 +13,7 @@ from jaxatari.environment import JaxEnvironment, JAXAtariAction as Action class UpNDownConstants(NamedTuple): - FRAME_SKIP: int = 4 # Used by AtariWrapper + FRAME_SKIP: int = 4 DIFFICULTIES: chex.Array = jnp.array([0, 1, 2, 3, 4, 5]) MAX_SPEED: int = 6 INITIAL_LIVES: int = 5 @@ -326,12 +325,7 @@ def _get_x_on_road(self, y: chex.Array, road_segment: chex.Array, track_corners_ x2 = track_corners_x[road_segment + 1] # Linear interpolation: x = x1 + (y - y1) * (x2 - x1) / (y2 - y1) - t = jax.lax.cond( - y2 != y1, - lambda _: (y - y1) / (y2 - y1), - lambda _: 0.0, - operand=None, - ) + t = jnp.where(y2 != y1, (y - y1) / (y2 - y1), 0.0) return x1 + t * (x2 - x1) @partial(jax.jit, static_argnums=(0,)) @@ -582,7 +576,7 @@ def _advance_player_car( ) # Wrap Y position for looping track - wrapped_y = -((new_player_y * -1) % 1036) + wrapped_y = -((new_player_y * -1) % self.consts.TRACK_LENGTH) return Car( position=EntityPosition( @@ -654,25 +648,14 @@ def _advance_car_core( operand=None, ) - wrapped_y = -((new_y * -1) % 1036) + wrapped_y = -((new_y * -1) % self.consts.TRACK_LENGTH) # Update road segment indices based on new position segment_from_y = self._get_road_segment(new_y) - # Update road indices to track the current segment - next_road_index_A = jax.lax.cond( - current_road == 0, - lambda _: segment_from_y, - lambda _: road_index_A, - operand=None, - ) - - next_road_index_B = jax.lax.cond( - current_road == 1, - lambda _: segment_from_y, - lambda _: road_index_B, - operand=None, - ) + # Update road indices to track the current segment (use jnp.where for branchless execution) + next_road_index_A = jnp.where(current_road == 0, segment_from_y, road_index_A) + next_road_index_B = jnp.where(current_road == 1, segment_from_y, road_index_B) return Car( position=EntityPosition( @@ -689,6 +672,7 @@ def _advance_car_core( type=car_type, ) + @partial(jax.jit, static_argnums=(0,)) def _flag_step(self, state: UpNDownState, new_player_y: chex.Array, player_x: chex.Array, current_road: chex.Array) -> Tuple[Flag, chex.Array, chex.Array]: """Update flag collection state and score. @@ -748,7 +732,8 @@ def check_flag_collision(flag_idx): ) return new_flags, flag_score, new_flags_collected_mask - + + @partial(jax.jit, static_argnums=(0,)) def _collectible_step(self, state: UpNDownState, new_player_y: chex.Array, player_x: chex.Array, current_road: chex.Array) -> Tuple[Collectible, chex.Array, chex.Array]: """Update collectible spawning, despawning, and collection (unified for all types). @@ -764,30 +749,20 @@ def _collectible_step(self, state: UpNDownState, new_player_y: chex.Array, playe Returns: Tuple of (updated_collectibles, score_delta, new_spawn_timer) """ - # Collectible spawning logic - decrement timer and spawn when ready - new_collectible_timer = jax.lax.cond( + # Collectible spawning logic - decrement timer and spawn when ready (use jnp.where for branchless) + new_collectible_timer = jnp.where( state.collectible_spawn_timer <= 0, - lambda _: self.consts.COLLECTIBLE_SPAWN_INTERVAL, - lambda _: state.collectible_spawn_timer - 1, - operand=None, + self.consts.COLLECTIBLE_SPAWN_INTERVAL, + state.collectible_spawn_timer - 1, ) # Attempt to spawn when timer hits 0 should_spawn = state.collectible_spawn_timer <= 0 - # Find first inactive collectible slot - def find_inactive_idx(collectibles_in): - inactive_mask = ~collectibles_in.active - first_inactive = jnp.argmax(inactive_mask.astype(jnp.int32)) - has_inactive = jnp.any(inactive_mask) - return jax.lax.cond( - has_inactive, - lambda _: first_inactive, - lambda _: jnp.array(0, dtype=jnp.int32), - operand=None, - ), has_inactive - - spawn_idx, has_inactive_slot = find_inactive_idx(state.collectibles) + inactive_mask = ~state.collectibles.active + first_inactive = jnp.argmax(inactive_mask.astype(jnp.int32)) + has_inactive_slot = jnp.any(inactive_mask) + spawn_idx = jnp.where(has_inactive_slot, first_inactive, jnp.array(0, dtype=jnp.int32)) # Generate random spawn position using fold_in for deterministic randomness base_key = jax.random.PRNGKey(0) @@ -806,13 +781,12 @@ def find_inactive_idx(collectibles_in): type_id_spawn = jnp.searchsorted(self.consts.COLLECTIBLE_SPAWN_PROBABILITIES, rand_type, side='right') type_id_spawn = jnp.clip(type_id_spawn, 0, 3).astype(jnp.int32) - # Calculate X position on road + # Calculate X position on road (use jnp.where for branchless) segment_spawn = self._get_road_segment(y_spawn) - x_spawn = jax.lax.cond( + x_spawn = jnp.where( road_spawn == 0, - lambda _: self._get_x_on_road(y_spawn, segment_spawn, self.consts.FIRST_TRACK_CORNERS_X), - lambda _: self._get_x_on_road(y_spawn, segment_spawn, self.consts.SECOND_TRACK_CORNERS_X), - operand=None, + self._get_x_on_road(y_spawn, segment_spawn, self.consts.FIRST_TRACK_CORNERS_X), + self._get_x_on_road(y_spawn, segment_spawn, self.consts.SECOND_TRACK_CORNERS_X), ) # Create mask for which collectibles to update @@ -863,16 +837,9 @@ def check_collision(idx): # Deactivate collected items final_active = jnp.logical_and(active_after_despawn, ~collections) - # Update score - use type_id to look up score value - def get_collection_score(idx): - is_collected = collections[idx] - type_id = spawned_type_id[idx] - # Look up score based on type_id using array indexing - score = self.consts.COLLECTIBLE_SCORES[type_id] - return jnp.where(is_collected, score, 0) - - score_array = jax.vmap(get_collection_score)(jnp.arange(self.consts.MAX_COLLECTIBLES)) - score_delta = jnp.sum(score_array) + # Update score - vectorized lookup without vmap overhead + scores = self.consts.COLLECTIBLE_SCORES[spawned_type_id] + score_delta = jnp.sum(jnp.where(collections, scores, 0)) # Create final collectibles state updated_collectibles = Collectible( @@ -885,74 +852,66 @@ def get_collection_score(idx): ) return updated_collectibles, score_delta, new_collectible_timer + + @partial(jax.jit, static_argnums=(0,)) def _death_step(self, state: UpNDownState) -> UpNDownState: - # Player on water road (index 2 assumed water) + """Handle player death when on water road (index 2).""" + # Player on water road (index 2 assumed water) died = jnp.logical_and( state.player_car.current_road == 2, ~state.is_dead, - ) + ) - lives = jax.lax.cond( + # Use jnp.where for branchless execution + lives = jnp.where(died, state.lives - 1, state.lives) + respawn_timer = jnp.where( died, - lambda _: state.lives - 1, - lambda _: state.lives, - operand=None, - ) - lives = jax.lax.cond( - died, - lambda _: state.lives - 1, - lambda _: state.lives, - operand=None, - ) - respawn_timer = jax.lax.cond( - died, - lambda _: jnp.array(self.consts.RESPAWN_DELAY_FRAMES), - lambda _: jnp.maximum(state.respawn_timer - 1, 0), - operand=None, - ) + jnp.array(self.consts.RESPAWN_DELAY_FRAMES), + jnp.maximum(state.respawn_timer - 1, 0), + ) is_dead = jnp.logical_and( - jnp.logical_or(state.is_dead, died), - respawn_timer > 0) - - player_car = jax.lax.cond( - jnp.logical_and(state.is_dead, respawn_timer == 0), - lambda _: state.player_car._replace( - position=state.player_car.position._replace( - x=jnp.array(self.consts.RESPAWN_X, dtype=jnp.float32), - y=jnp.array(self.consts.RESPAWN_Y, dtype=jnp.float32), - ), - speed=jnp.array(0.0, dtype=jnp.float32), - current_road=0, - ), - lambda _: state.player_car, - operand=None, - ) + jnp.logical_or(state.is_dead, died), + respawn_timer > 0, + ) + + # Respawn player when dead and timer expires + should_respawn = jnp.logical_and(state.is_dead, respawn_timer == 0) + new_position = state.player_car.position._replace( + x=jnp.where(should_respawn, jnp.array(self.consts.RESPAWN_X, dtype=jnp.float32), state.player_car.position.x), + y=jnp.where(should_respawn, jnp.array(self.consts.RESPAWN_Y, dtype=jnp.float32), state.player_car.position.y), + ) + player_car = state.player_car._replace( + position=new_position, + speed=jnp.where(should_respawn, 0, state.player_car.speed), + current_road=jnp.where(should_respawn, 0, state.player_car.current_road), + ) + return state._replace( - lives=lives, - is_dead=is_dead, - respawn_timer=respawn_timer, - player_car=player_car, - ) + lives=lives, + is_dead=is_dead, + respawn_timer=respawn_timer, + player_car=player_car, + ) + @partial(jax.jit, static_argnums=(0,)) def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: up = jnp.logical_or(action == Action.UP, action == Action.UPFIRE) down = jnp.logical_or(action == Action.DOWN, action == Action.DOWNFIRE) jump = jnp.logical_or(action == Action.FIRE, jnp.logical_or(action == Action.UPFIRE, action == Action.DOWNFIRE)) player_speed = state.player_car.speed.astype(jnp.float32) - player_speed = jax.lax.cond( + # Use jnp.where for branchless execution + player_speed = jnp.where( jnp.logical_and(state.player_car.speed < self.consts.MAX_SPEED, up), - lambda s: s + 1, - lambda s: s, - operand=player_speed, + player_speed + 1, + player_speed, ) - player_speed = jax.lax.cond( + player_speed = jnp.where( jnp.logical_and(state.player_car.speed > -self.consts.MAX_SPEED, down), - lambda s: s - 1, - lambda s: s, - operand=player_speed, + player_speed - 1, + player_speed, ) # Check if on a steep road section (no X direction change) and apply speed reduction @@ -976,31 +935,28 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: player_speed > 0 ) ) - # Update steep road timer - increment when on steep road going up, reset otherwise - steep_road_timer = jax.lax.cond( + # Update steep road timer - increment when on steep road going up, reset otherwise (use jnp.where) + steep_road_timer = jnp.where( on_steep_going_up, - lambda _: state.steep_road_timer + 1, - lambda _: jnp.array(0, dtype=jnp.int32), - operand=None, + state.steep_road_timer + 1, + jnp.array(0, dtype=jnp.int32), ) # Only reduce speed when timer reaches the interval threshold should_reduce_speed = jnp.logical_and( on_steep_going_up, steep_road_timer >= self.consts.STEEP_ROAD_SPEED_REDUCTION_INTERVAL ) - # Gradually reduce speed toward -2 when on steep section without jumping - player_speed = jax.lax.cond( + # Gradually reduce speed toward -2 when on steep section without jumping (use jnp.where) + player_speed = jnp.where( should_reduce_speed, - lambda s: jnp.maximum(s - 1, jnp.int32(-2)), - lambda s: s, - operand=player_speed, + jnp.maximum(player_speed - 1, jnp.int32(-2)), + player_speed, ) - # Reset timer after speed reduction - steep_road_timer = jax.lax.cond( + # Reset timer after speed reduction (use jnp.where) + steep_road_timer = jnp.where( should_reduce_speed, - lambda _: jnp.array(0, dtype=jnp.int32), - lambda _: steep_road_timer, - operand=None, + jnp.array(0, dtype=jnp.int32), + steep_road_timer, ) can_start_jump = jnp.logical_and(state.jump_cooldown == 0, state.post_jump_cooldown == 0) @@ -1014,81 +970,63 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: # Calculate jump slope at jump start (X change per Y step) # Uses the road segment slope to follow the road trajectory - road_index = jax.lax.cond( + # Use jnp.where for branchless execution + road_index = jnp.where( state.player_car.current_road == 0, - lambda _: state.player_car.road_index_A, - lambda _: state.player_car.road_index_B, - operand=None, + state.player_car.road_index_A, + state.player_car.road_index_B, ) # Get corner coordinates for the current segment # Segment goes from corner[road_index] to corner[road_index+1] - start_x = jax.lax.cond( + # Use jnp.where for branchless execution + start_x = jnp.where( state.player_car.current_road == 0, - lambda _: self.consts.FIRST_TRACK_CORNERS_X[road_index], - lambda _: self.consts.SECOND_TRACK_CORNERS_X[road_index], - operand=None, + self.consts.FIRST_TRACK_CORNERS_X[road_index], + self.consts.SECOND_TRACK_CORNERS_X[road_index], ) - end_x = jax.lax.cond( + end_x = jnp.where( state.player_car.current_road == 0, - lambda _: self.consts.FIRST_TRACK_CORNERS_X[road_index +1], - lambda _: self.consts.SECOND_TRACK_CORNERS_X[road_index +1], - operand=None, + self.consts.FIRST_TRACK_CORNERS_X[road_index + 1], + self.consts.SECOND_TRACK_CORNERS_X[road_index + 1], ) start_y = self.consts.TRACK_CORNERS_Y[road_index] - end_y = jax.lax.cond( + end_y = jnp.where( jnp.equal(self.consts.FIRST_TRACK_CORNERS_X[road_index + 1], self.consts.FIRST_TRACK_CORNERS_X[road_index + 2]), - lambda _: self.consts.TRACK_CORNERS_Y[road_index + 2], - lambda _: self.consts.TRACK_CORNERS_Y[road_index + 1], - operand=None + self.consts.TRACK_CORNERS_Y[road_index + 2], + self.consts.TRACK_CORNERS_Y[road_index + 1], ) # Calculate slope: how much X changes per unit Y change delta_x = end_x - start_x delta_y = end_y - start_y - # Avoid division by zero for horizontal segments - new_jump_slope = jax.lax.cond( + # Avoid division by zero for horizontal segments (use jnp.where) + new_jump_slope = jnp.where( jnp.abs(delta_y) > 0.001, - lambda _: jnp.float32(delta_x) / jnp.float32(delta_y), - lambda _: jnp.float32(0.0), - operand=None, + jnp.float32(delta_x) / jnp.float32(delta_y), + jnp.float32(0.0), ) - # Lock slope at jump start, keep previous slope during jump - jump_slope = jax.lax.cond( - starting_jump, - lambda _: new_jump_slope, - lambda _: state.jump_slope, - operand=None, - ) + # Lock slope at jump start, keep previous slope during jump (use jnp.where) + jump_slope = jnp.where(starting_jump, new_jump_slope, state.jump_slope) - jump_cooldown = jax.lax.cond( + # Use jnp.where for branchless execution of jump_cooldown + jump_cooldown = jnp.where( state.jump_cooldown > 0, - lambda s: s - 1, - lambda s: jax.lax.cond( - is_jumping, - lambda _: self.consts.JUMP_FRAMES, - lambda _: 0, - operand=None, - ), - operand=state.jump_cooldown, + state.jump_cooldown - 1, + jnp.where(is_jumping, self.consts.JUMP_FRAMES, 0), ) - post_jump_cooldown = jax.lax.cond( - jnp.logical_and(state.jump_cooldown == 1, jump_cooldown == 0), - lambda _: self.consts.POST_JUMP_DELAY, - lambda _: jax.lax.cond( - state.post_jump_cooldown > 0, - lambda s: s - 1, - lambda s: s, - operand=state.post_jump_cooldown, - ), - operand=None, + # Use jnp.where for branchless execution of post_jump_cooldown + is_landing_now = jnp.logical_and(state.jump_cooldown == 1, jump_cooldown == 0) + post_jump_cooldown = jnp.where( + is_landing_now, + self.consts.POST_JUMP_DELAY, + jnp.where(state.post_jump_cooldown > 0, state.post_jump_cooldown - 1, state.post_jump_cooldown), ) is_on_road = ~is_jumping - is_landing = jnp.logical_and(state.jump_cooldown == 1, jump_cooldown == 0) - jump_arc_height = self.consts.JUMP_ARC_HEIGHT * jump_boost_multiplier + is_landing = is_landing_now updated_player_car = self._advance_player_car( position_x=state.player_car.position.x, @@ -1119,12 +1057,7 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: player_car=updated_player_car, step_counter=state.step_counter + 1, round_started=round_started_now, - movement_steps=jax.lax.cond( - round_started_now, - lambda _: state.movement_steps + 1, - lambda _: state.movement_steps, - operand=None, - ), + movement_steps=jnp.where(round_started_now, state.movement_steps + 1, state.movement_steps), steep_road_timer=steep_road_timer, jump_slope=jump_slope, ) @@ -1138,6 +1071,7 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: operand=None, ) + @partial(jax.jit, static_argnums=(0,)) def _flag_step_main(self, state: UpNDownState) -> UpNDownState: """Update flag collection state and score.""" new_player_y = state.player_car.position.y @@ -1153,7 +1087,16 @@ def _flag_step_main(self, state: UpNDownState) -> UpNDownState: flags=new_flags, flags_collected_mask=new_flags_collected_mask, ) - + + @partial(jax.jit, static_argnums=(0,)) + def _completion_bonus_step(self, state: UpNDownState) -> UpNDownState: + """Award bonus when all flags are collected.""" + all_flags_collected = jnp.all(state.flags_collected_mask) + # Use jnp.where for branchless execution + bonus = jnp.where(all_flags_collected, self.consts.ALL_FLAGS_BONUS, 0) + return state._replace(score=state.score + bonus) + + @partial(jax.jit, static_argnums=(0,)) def _collectible_step_main(self, state: UpNDownState) -> UpNDownState: """Update collectible spawning, despawning, and collection.""" new_player_y = state.player_car.position.y @@ -1170,6 +1113,7 @@ def _collectible_step_main(self, state: UpNDownState) -> UpNDownState: collectible_spawn_timer=new_collectible_timer, ) + @partial(jax.jit, static_argnums=(0,)) def _initialize_collectibles(self) -> Collectible: """Return a cleared collectible pool.""" return Collectible( @@ -1247,32 +1191,33 @@ def _enemy_step_main(self, state: UpNDownState) -> UpNDownState: active_count = jnp.sum(active_mask.astype(jnp.int32)) can_spawn = active_count < self.consts.MAX_ENEMY_CARS - spawn_timer = jax.lax.cond( + # Use jnp.where for branchless execution + spawn_timer = jnp.where( jnp.logical_and(state.enemy_spawn_timer <= 0, can_spawn), - lambda _: self.consts.ENEMY_SPAWN_INTERVAL, - lambda _: state.enemy_spawn_timer - 1, - operand=None, + self.consts.ENEMY_SPAWN_INTERVAL, + state.enemy_spawn_timer - 1, ) should_spawn = jnp.logical_and(state.enemy_spawn_timer <= 0, can_spawn) inactive_mask = jnp.logical_not(active_mask) first_inactive = jnp.argmax(inactive_mask.astype(jnp.int32)) has_inactive = jnp.any(inactive_mask) - spawn_idx = jax.lax.cond(has_inactive, lambda _: first_inactive, lambda _: jnp.array(0, dtype=jnp.int32), operand=None) + # Use jnp.where for branchless execution + spawn_idx = jnp.where(has_inactive, first_inactive, jnp.array(0, dtype=jnp.int32)) spawn_mask = (jnp.arange(self.consts.MAX_ENEMY_CARS) == spawn_idx) & should_spawn & has_inactive spawn_offset = self.consts.ENEMY_OFFSCREEN_SPAWN_OFFSET + active_count * self.consts.ENEMY_MIN_SPAWN_GAP + jax.random.uniform(key_spawn_offset, minval=0.0, maxval=40.0) spawn_side = jax.random.choice(key_spawn_side, jnp.array([-1.0, 1.0])) raw_spawn_y = state.player_car.position.y + spawn_side * spawn_offset - spawn_y = -(((raw_spawn_y) * -1) % 1036) + spawn_y = -(((raw_spawn_y) * -1) % self.consts.TRACK_LENGTH) spawn_road = jax.random.randint(key_spawn_direction, shape=(), minval=0, maxval=2) segment_spawn = self._get_road_segment(spawn_y) - spawn_x = jax.lax.cond( + # Use jnp.where for branchless execution + spawn_x = jnp.where( spawn_road == 0, - lambda _: self._get_x_on_road(spawn_y, segment_spawn, self.consts.FIRST_TRACK_CORNERS_X), - lambda _: self._get_x_on_road(spawn_y, segment_spawn, self.consts.SECOND_TRACK_CORNERS_X), - operand=None, + self._get_x_on_road(spawn_y, segment_spawn, self.consts.FIRST_TRACK_CORNERS_X), + self._get_x_on_road(spawn_y, segment_spawn, self.consts.SECOND_TRACK_CORNERS_X), ) spawn_speed_mag = jax.random.randint(key_spawn_speed, shape=(), minval=self.consts.ENEMY_SPEED_MIN, maxval=self.consts.ENEMY_SPEED_MAX + 1) @@ -1280,13 +1225,13 @@ def _enemy_step_main(self, state: UpNDownState) -> UpNDownState: spawn_speed = spawn_speed_mag * spawn_speed_sign spawn_type = jax.random.randint(key_spawn_type, shape=(), minval=0, maxval=4) - direction_raw = jax.lax.cond( + # Use jnp.where for branchless execution + direction_raw = jnp.where( spawn_road == 0, - lambda _: self.consts.FIRST_TRACK_CORNERS_X[segment_spawn+1] - self.consts.FIRST_TRACK_CORNERS_X[segment_spawn], - lambda _: self.consts.SECOND_TRACK_CORNERS_X[segment_spawn+1] - self.consts.SECOND_TRACK_CORNERS_X[segment_spawn], - operand=None, + self.consts.FIRST_TRACK_CORNERS_X[segment_spawn+1] - self.consts.FIRST_TRACK_CORNERS_X[segment_spawn], + self.consts.SECOND_TRACK_CORNERS_X[segment_spawn+1] - self.consts.SECOND_TRACK_CORNERS_X[segment_spawn], ) - spawn_direction_x = jax.lax.cond(direction_raw > 0, lambda _: 1, lambda _: -1, operand=None) + spawn_direction_x = jnp.where(direction_raw > 0, 1, -1) enemy_position_x = jnp.where(spawn_mask, spawn_x, state.enemy_cars.position.x) enemy_position_y = jnp.where(spawn_mask, spawn_y, state.enemy_cars.position.y) @@ -1338,7 +1283,7 @@ def _enemy_step_main(self, state: UpNDownState) -> UpNDownState: enemy_age = jnp.where(enemy_active, enemy_age + 1, enemy_age) delta_y = moved_position_y - state.player_car.position.y - wrapped_dist = jnp.minimum(jnp.abs(delta_y), 1036 - jnp.abs(delta_y)) + wrapped_dist = jnp.minimum(jnp.abs(delta_y), self.consts.TRACK_LENGTH - jnp.abs(delta_y)) far_mask = wrapped_dist > self.consts.ENEMY_DESPAWN_DISTANCE age_mask = enemy_age > self.consts.ENEMY_MAX_AGE despawn_mask = jnp.logical_and(enemy_active, jnp.logical_or(far_mask, age_mask)) @@ -1513,17 +1458,14 @@ def handle_ground_collision(): operand=None, ) + @partial(jax.jit, static_argnums=(0,)) def _passive_score_step_main(self, state: UpNDownState) -> UpNDownState: """Award passive score at regular intervals after the player has started moving.""" - bonus = jax.lax.cond( - jnp.logical_and( - state.round_started, - state.movement_steps % self.consts.PASSIVE_SCORE_INTERVAL == 0, - ), - lambda _: jnp.int32(self.consts.PASSIVE_SCORE_AMOUNT), - lambda _: jnp.int32(0), - operand=None, + should_award = jnp.logical_and( + state.round_started, + state.movement_steps % self.consts.PASSIVE_SCORE_INTERVAL == 0, ) + bonus = jnp.where(should_award, jnp.int32(self.consts.PASSIVE_SCORE_AMOUNT), jnp.int32(0)) return state._replace(score=state.score + bonus) @@ -1733,7 +1675,7 @@ def _createBackgroundSprite(self, dimensions: Tuple[int, int]) -> jnp.ndarray: self.view_height = self.config.game_dimensions[0] # Precompute offsets so repeated road tiles can wrap seamlessly without gaps. road_cycle = max(1, self.complete_road_size) - repeats = max(1, math.ceil(self.view_height / road_cycle) + 2) + repeats = max(1, int(-(-self.view_height // road_cycle)) + 2) # Ceiling division trick self._road_tile_offsets = jnp.arange(-repeats, repeats + 1, dtype=jnp.int32) * jnp.int32(self.complete_road_size) self._num_road_tiles = int(self._road_tile_offsets.shape[0]) From aa66641f2498d460d58164fd625e8e235f108d3c Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sun, 21 Dec 2025 15:53:34 +0100 Subject: [PATCH 25/76] improve enemy spawning --- src/jaxatari/games/jax_upndown.py | 73 ++++++++++++++++++++++--------- 1 file changed, 53 insertions(+), 20 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 7735a0987..b94c73a8a 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -24,17 +24,21 @@ class UpNDownConstants(NamedTuple): ALL_FLAGS_BONUS: int = 1000 # Enemy spawning and movement MAX_ENEMY_CARS: int = 8 - ENEMY_SPAWN_INTERVAL: int = 80 - ENEMY_DESPAWN_DISTANCE: int = 300 + ENEMY_SPAWN_INTERVAL_BASE: int = 30 # Base spawn interval + ENEMY_SPAWN_INTERVAL_MAX: int = 60 # Max spawn interval when many enemies exist + ENEMY_MIN_VISIBLE_COUNT: int = 2 # Minimum enemies to keep on screen + ENEMY_VISIBLE_DISTANCE: int = 120 # Distance within which enemies are considered "visible" + ENEMY_DESPAWN_DISTANCE: int = 250 ENEMY_SPEED_MIN: int = 3 ENEMY_SPEED_MAX: int = 5 ENEMY_DIRECTION_SWITCH_PROB: float = 0.0001 - ENEMY_OFFSCREEN_SPAWN_OFFSET: float = 100.0 - ENEMY_MIN_SPAWN_GAP: float = 30.0 - ENEMY_MAX_AGE: int = 1900 + ENEMY_SPAWN_OFFSET_MIN: float = 70.0 # Closer spawn distance + ENEMY_SPAWN_OFFSET_MAX: float = 130.0 # Max spawn offset + ENEMY_MIN_SPAWN_GAP: float = 25.0 # Reduced gap between spawns + ENEMY_MAX_AGE: int = 1900 INITIAL_ENEMY_COUNT: int = 4 - INITIAL_ENEMY_BASE_OFFSET: float = 40.0 - INITIAL_ENEMY_GAP: float = 30.0 + INITIAL_ENEMY_BASE_OFFSET: float = 35.0 # Closer initial enemies + INITIAL_ENEMY_GAP: float = 25.0 # Tighter initial spacing ENEMY_TYPE_CAMERO: int = 0 ENEMY_TYPE_FLAG_CARRIER: int = 1 ENEMY_TYPE_PICKUP: int = 2 @@ -82,7 +86,7 @@ class UpNDownConstants(NamedTuple): LIFE_BOTTOM_X_POSITIONS: chex.Array = jnp.array([13, 18, 25, 33, 33]) # X positions for 5 life cars LIFE_BOTTOM_Y: int = 195 # Collectible constants - unified dynamic spawning - MAX_COLLECTIBLES: int = 2 # Maximum collectibles that can exist at once (pool of mixed types) + MAX_COLLECTIBLES: int = 1 # Maximum collectibles that can exist at once (pool of mixed types) COLLECTIBLE_SPAWN_INTERVAL: int = 200 # Steps between spawn attempts COLLECTIBLE_DESPAWN_DISTANCE: int = 500 # Distance beyond which collectibles despawn # Collectible types (indices for type field) @@ -1182,7 +1186,7 @@ def init_direction(seg, road): @partial(jax.jit, static_argnums=(0,)) def _enemy_step_main(self, state: UpNDownState) -> UpNDownState: - """Spawn and move enemy cars that share the player's road logic.""" + """Spawn and move enemy cars with adaptive spawning for consistent enemy presence.""" base_key = jax.random.PRNGKey(2025) step_key = jax.random.fold_in(base_key, state.step_counter) key_spawn_offset, key_spawn_side, key_spawn_speed, key_spawn_direction, key_spawn_type, key_spawn_sign, key_flip_root = jax.random.split(step_key, 7) @@ -1191,29 +1195,59 @@ def _enemy_step_main(self, state: UpNDownState) -> UpNDownState: active_count = jnp.sum(active_mask.astype(jnp.int32)) can_spawn = active_count < self.consts.MAX_ENEMY_CARS - # Use jnp.where for branchless execution + # Calculate how many enemies are "visible" (within visible distance of player) + player_y = state.player_car.position.y + enemy_distances = jnp.abs(state.enemy_cars.position.y - player_y) + wrapped_distances = jnp.minimum(enemy_distances, self.consts.TRACK_LENGTH - enemy_distances) + visible_mask = jnp.logical_and(active_mask, wrapped_distances < self.consts.ENEMY_VISIBLE_DISTANCE) + visible_count = jnp.sum(visible_mask.astype(jnp.int32)) + + # Adaptive spawn interval: spawn faster when fewer visible enemies + # If below minimum, spawn immediately (interval = 0) + # Otherwise scale between BASE and MAX based on visible count + needs_urgent_spawn = visible_count < self.consts.ENEMY_MIN_VISIBLE_COUNT + spawn_interval = jnp.where( + needs_urgent_spawn, + jnp.int32(0), # Spawn immediately when too few visible + jnp.int32(self.consts.ENEMY_SPAWN_INTERVAL_BASE + + (visible_count * (self.consts.ENEMY_SPAWN_INTERVAL_MAX - self.consts.ENEMY_SPAWN_INTERVAL_BASE)) // + self.consts.MAX_ENEMY_CARS) + ) + + # Spawn when timer expires OR when we urgently need more enemies + timer_expired = state.enemy_spawn_timer <= 0 + should_spawn = jnp.logical_and( + jnp.logical_or(timer_expired, needs_urgent_spawn), + can_spawn + ) + + # Reset timer with adaptive interval spawn_timer = jnp.where( - jnp.logical_and(state.enemy_spawn_timer <= 0, can_spawn), - self.consts.ENEMY_SPAWN_INTERVAL, - state.enemy_spawn_timer - 1, + should_spawn, + spawn_interval, + jnp.maximum(state.enemy_spawn_timer - 1, 0), ) - should_spawn = jnp.logical_and(state.enemy_spawn_timer <= 0, can_spawn) inactive_mask = jnp.logical_not(active_mask) first_inactive = jnp.argmax(inactive_mask.astype(jnp.int32)) has_inactive = jnp.any(inactive_mask) - # Use jnp.where for branchless execution spawn_idx = jnp.where(has_inactive, first_inactive, jnp.array(0, dtype=jnp.int32)) spawn_mask = (jnp.arange(self.consts.MAX_ENEMY_CARS) == spawn_idx) & should_spawn & has_inactive - spawn_offset = self.consts.ENEMY_OFFSCREEN_SPAWN_OFFSET + active_count * self.consts.ENEMY_MIN_SPAWN_GAP + jax.random.uniform(key_spawn_offset, minval=0.0, maxval=40.0) + # Spawn closer when urgent (fewer visible enemies), farther when plenty exist + base_offset = jnp.where( + needs_urgent_spawn, + self.consts.ENEMY_SPAWN_OFFSET_MIN, # Spawn closer when needed + self.consts.ENEMY_SPAWN_OFFSET_MIN + visible_count * 10.0 # Farther when plenty exist + ) + spawn_offset = base_offset + jax.random.uniform(key_spawn_offset, minval=0.0, maxval=30.0) + spawn_side = jax.random.choice(key_spawn_side, jnp.array([-1.0, 1.0])) raw_spawn_y = state.player_car.position.y + spawn_side * spawn_offset spawn_y = -(((raw_spawn_y) * -1) % self.consts.TRACK_LENGTH) spawn_road = jax.random.randint(key_spawn_direction, shape=(), minval=0, maxval=2) segment_spawn = self._get_road_segment(spawn_y) - # Use jnp.where for branchless execution spawn_x = jnp.where( spawn_road == 0, self._get_x_on_road(spawn_y, segment_spawn, self.consts.FIRST_TRACK_CORNERS_X), @@ -1225,7 +1259,6 @@ def _enemy_step_main(self, state: UpNDownState) -> UpNDownState: spawn_speed = spawn_speed_mag * spawn_speed_sign spawn_type = jax.random.randint(key_spawn_type, shape=(), minval=0, maxval=4) - # Use jnp.where for branchless execution direction_raw = jnp.where( spawn_road == 0, self.consts.FIRST_TRACK_CORNERS_X[segment_spawn+1] - self.consts.FIRST_TRACK_CORNERS_X[segment_spawn], @@ -1369,7 +1402,7 @@ def _respawn_after_collision(self, state: UpNDownState, new_lives: chex.Array) - collectibles=collectibles, collectible_spawn_timer=jnp.array(0, dtype=jnp.int32), enemy_cars=enemy_cars, - enemy_spawn_timer=jnp.array(self.consts.ENEMY_SPAWN_INTERVAL, dtype=jnp.int32), + enemy_spawn_timer=jnp.array(self.consts.ENEMY_SPAWN_INTERVAL_BASE, dtype=jnp.int32), ) @partial(jax.jit, static_argnums=(0,)) @@ -1540,7 +1573,7 @@ def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: collectibles=collectibles, collectible_spawn_timer=jnp.array(0, dtype=jnp.int32), enemy_cars=enemy_cars, - enemy_spawn_timer=jnp.array(self.consts.ENEMY_SPAWN_INTERVAL, dtype=jnp.int32), + enemy_spawn_timer=jnp.array(self.consts.ENEMY_SPAWN_INTERVAL_BASE, dtype=jnp.int32), ) initial_obs = self._get_observation(state) return initial_obs, state From e5e3bc1c5918a0017c1fa66aff5d1c5b1f1b098e Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sun, 21 Dec 2025 16:06:09 +0100 Subject: [PATCH 26/76] improve jumping on enemys --- src/jaxatari/games/jax_upndown.py | 34 ++++++++++++++----------------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index b94c73a8a..316011696 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -45,9 +45,9 @@ class UpNDownConstants(NamedTuple): ENEMY_TYPE_TRUCK: int = 3 JUMP_FRAMES: int = 28 POST_JUMP_DELAY: int = 10 - LANDING_TOLERANCE: int = 15 # Pixels tolerance for landing on a road (increased by 5 for off-road landings) + LANDING_TOLERANCE: int = 20 # Pixels tolerance for landing on a road (increased by 5 for wider landing zone) LATE_JUMP_COLLISION_FRAMES: int = 2 - LANDING_COLLISION_DISTANCE: float = 8.0 # Larger collision distance when landing (for crossings) + LANDING_COLLISION_DISTANCE: float = 12.0 # Larger collision distance when landing (increased for easier enemy kills) GROUND_COLLISION_DISTANCE: float = 3.0 # Tight collision distance for ground collisions LATE_JUMP_ENEMY_SCORE: int = 400 STEEP_ROAD_SPEED_REDUCTION_INTERVAL: int = 12 # Frames between each speed reduction on steep roads @@ -1425,34 +1425,30 @@ def _enemy_collision_step_main(self, state: UpNDownState) -> UpNDownState: # For ground collision: only trigger when enemy position is within tight distance overlap_x_ground = dx <= self.consts.GROUND_COLLISION_DISTANCE overlap_y_ground = wrapped_dy <= self.consts.GROUND_COLLISION_DISTANCE - # For landing collision: use larger distance and road-independent (for crossings) - overlap_x_landing = dx <= self.consts.LANDING_COLLISION_DISTANCE - overlap_y_landing = wrapped_dy <= self.consts.LANDING_COLLISION_DISTANCE - # For late jump collision: use original larger overlap based on car dimensions + # For late jump collision: use larger overlap based on car dimensions overlap_x_jump = dx <= (state.player_car.position.width + state.enemy_cars.position.width) / 2.0 overlap_y_jump = wrapped_dy <= (state.player_car.position.height + state.enemy_cars.position.height) / 2.0 same_road = state.enemy_cars.current_road == state.player_car.current_road # Ground collision mask uses tight 3-pixel distance and same road ground_collision_mask = jnp.logical_and(state.enemy_cars.active, jnp.logical_and(same_road, jnp.logical_and(overlap_x_ground, overlap_y_ground))) - # Landing collision mask uses larger distance and is road-independent (for crossings) - landing_collision_mask = jnp.logical_and(state.enemy_cars.active, jnp.logical_and(overlap_x_landing, overlap_y_landing)) - # Jump collision mask uses original larger overlap (for scoring when jumping on enemies) - jump_collision_mask = jnp.logical_and(state.enemy_cars.active, jnp.logical_and(same_road, jnp.logical_and(overlap_x_jump, overlap_y_jump))) + # Jump collision mask is road-independent - can destroy enemies on either road when jumping + jump_collision_mask = jnp.logical_and(state.enemy_cars.active, jnp.logical_and(overlap_x_jump, overlap_y_jump)) collision_mask = jump_collision_mask # For late jump scoring any_jump_collision = jnp.any(jump_collision_mask) any_ground_collision = jnp.any(ground_collision_mask) - any_landing_collision = jnp.any(landing_collision_mask) - # Check if player is in the landing phase (just landed from a jump) - is_landing_phase = jnp.logical_and(state.post_jump_cooldown > 0, state.post_jump_cooldown <= self.consts.POST_JUMP_DELAY) + # Check if player is in post-landing invincibility phase + is_invincible = state.post_jump_cooldown > 0 late_jump_window = jnp.logical_and(state.is_jumping, state.jump_cooldown <= self.consts.LATE_JUMP_COLLISION_FRAMES) late_jump_collision = jnp.logical_and(any_jump_collision, late_jump_window) - grounded_collision = jnp.logical_and(any_ground_collision, jnp.logical_not(state.is_jumping)) - # Landing collision is road-independent and uses larger distance - landing_collision = jnp.logical_and(any_landing_collision, is_landing_phase) + # Ground collision only applies when not jumping AND not in post-landing invincibility + grounded_collision = jnp.logical_and( + any_ground_collision, + jnp.logical_and(jnp.logical_not(state.is_jumping), jnp.logical_not(is_invincible)) + ) def handle_late_jump(): hits = collision_mask.astype(jnp.int32) @@ -1476,8 +1472,8 @@ def handle_late_jump(): def handle_ground_collision(): return self._respawn_after_collision(state, state.lives - 1) - # Check for any collision that should cause respawn (ground or landing) - any_fatal_collision = jnp.logical_or(grounded_collision, landing_collision) + # Ground collision causes respawn (landing is now protected by invincibility) + any_fatal_collision = grounded_collision return jax.lax.cond( late_jump_collision, @@ -1946,7 +1942,7 @@ def render_enemy(carry, enemy_idx): player_screen_y = jnp.int32(105 - jump_offset) player_mask = self.SHAPE_MASKS["player"] - raster_player = self.jr.render_at(raster_enemies, state.player_car.position.x, player_screen_y, player_mask) + raster_player = self.jr.render_at_clipped(raster_enemies, state.player_car.position.x, player_screen_y, player_mask) wall_top_mask = self.SHAPE_MASKS["wall_top"] raster_wall_top = self.jr.render_at(raster_player, 0, 0, wall_top_mask) From 11433f4ac8e20d12fd051f8a202713293afc5f80 Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sun, 21 Dec 2025 16:22:10 +0100 Subject: [PATCH 27/76] reuse movement logic for enemys --- src/jaxatari/games/jax_upndown.py | 331 +++++++++++++----------------- 1 file changed, 144 insertions(+), 187 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 316011696..f1b717075 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -295,19 +295,21 @@ def _apply_steep_road_penalty( @partial(jax.jit, static_argnums=(0,)) def _get_slope_and_intercept_from_indices(self, current_road: chex.Array, road_index_A: chex.Array, road_index_B: chex.Array) -> Tuple[chex.Array, chex.Array]: - trackx, tracky, road_index = jax.lax.cond( - current_road == 0, - lambda _: (self.consts.FIRST_TRACK_CORNERS_X, self.consts.TRACK_CORNERS_Y, road_index_A), - lambda _: (self.consts.SECOND_TRACK_CORNERS_X, self.consts.TRACK_CORNERS_Y, road_index_B), - operand=None, - ) - slope = jax.lax.cond( - trackx[road_index+1] - trackx[road_index] != 0, - lambda _: (tracky[road_index+1] - tracky[road_index]) / (trackx[road_index+1] - trackx[road_index]), - lambda _: 300.0, - operand=None, - ) - b = tracky[road_index] - slope * trackx[road_index] + """Calculate slope and intercept for the current road segment.""" + road_index = jnp.where(current_road == 0, road_index_A, road_index_B) + x1 = jnp.where(current_road == 0, + self.consts.FIRST_TRACK_CORNERS_X[road_index], + self.consts.SECOND_TRACK_CORNERS_X[road_index]) + x2 = jnp.where(current_road == 0, + self.consts.FIRST_TRACK_CORNERS_X[road_index + 1], + self.consts.SECOND_TRACK_CORNERS_X[road_index + 1]) + y1 = self.consts.TRACK_CORNERS_Y[road_index] + y2 = self.consts.TRACK_CORNERS_Y[road_index + 1] + + dx = x2 - x1 + dy = y2 - y1 + slope = jnp.where(dx != 0, dy / dx, 300.0) + b = y1 - slope * x1 return slope, b @partial(jax.jit, static_argnums=(0,)) @@ -332,6 +334,24 @@ def _get_x_on_road(self, y: chex.Array, road_segment: chex.Array, track_corners_ t = jnp.where(y2 != y1, (y - y1) / (y2 - y1), 0.0) return x1 + t * (x2 - x1) + @partial(jax.jit, static_argnums=(0,)) + def _get_x_for_road_index(self, y: chex.Array, road_segment: chex.Array, road_index: chex.Array) -> chex.Array: + """Get X position on road A (index 0) or road B (index 1) for given Y and segment.""" + track_corners = jnp.where( + road_index == 0, + self.consts.FIRST_TRACK_CORNERS_X[road_segment], + self.consts.SECOND_TRACK_CORNERS_X[road_segment], + ) + track_corners_next = jnp.where( + road_index == 0, + self.consts.FIRST_TRACK_CORNERS_X[road_segment + 1], + self.consts.SECOND_TRACK_CORNERS_X[road_segment + 1], + ) + y1 = self.consts.TRACK_CORNERS_Y[road_segment] + y2 = self.consts.TRACK_CORNERS_Y[road_segment + 1] + t = jnp.where(y2 != y1, (y - y1) / (y2 - y1), 0.0) + return track_corners + t * (track_corners_next - track_corners) + @partial(jax.jit, static_argnums=(0,)) def _get_road_segment(self, y: chex.Array) -> chex.Array: """Return the road segment index for a given y position.""" @@ -339,6 +359,56 @@ def _get_road_segment(self, y: chex.Array) -> chex.Array: max_idx = jnp.int32(len(self.consts.TRACK_CORNERS_Y) - 1) return jnp.clip(segments - 1, 0, max_idx) + @partial(jax.jit, static_argnums=(0,)) + def _compute_direction_x(self, current_road: chex.Array, road_index_A: chex.Array, road_index_B: chex.Array) -> chex.Array: + """Calculate the X direction for movement on the current road segment. + + Returns: + Direction as int32: -1 for left, 1 for right (defaults to -1 for vertical segments) + """ + # Select the road index based on which road we're on + road_index = jnp.where(current_road == 0, road_index_A, road_index_B) + # Select corners for the current road + x_curr = jnp.where(current_road == 0, + self.consts.FIRST_TRACK_CORNERS_X[road_index], + self.consts.SECOND_TRACK_CORNERS_X[road_index]) + x_next = jnp.where(current_road == 0, + self.consts.FIRST_TRACK_CORNERS_X[road_index + 1], + self.consts.SECOND_TRACK_CORNERS_X[road_index + 1]) + direction_raw = x_next - x_curr + return jnp.where(direction_raw == 0, -1, jnp.sign(direction_raw)).astype(jnp.int32) + + @partial(jax.jit, static_argnums=(0,)) + def _move_on_road( + self, + position: EntityPosition, + slope: chex.Array, + b: chex.Array, + speed_sign: chex.Array, + step_size: chex.Array, + car_direction_x: chex.Array, + move_y: chex.Array, + move_x: chex.Array, + ) -> Tuple[chex.Array, chex.Array]: + """Move a car on the road based on timing and geometry. + + Returns: + Tuple of (new_x, new_y) positions + """ + new_y = jnp.where( + jnp.logical_and(move_y, self._is_on_line_for_position(position, slope, b, speed_sign, 1)), + position.y + speed_sign * -step_size, + position.y, + ) + + new_x = jnp.where( + jnp.logical_and(move_x, self._is_on_line_for_position(position, slope, b, speed_sign, 2)), + position.x + speed_sign * car_direction_x * step_size, + position.x, + ) + + return new_x, new_y + @partial(jax.jit, static_argnums=(0,)) def _is_steep_road_segment(self, current_road: chex.Array, road_index_A: chex.Array, road_index_B: chex.Array) -> chex.Array: """Check if the current road segment is steep (no X direction change). @@ -349,12 +419,14 @@ def _is_steep_road_segment(self, current_road: chex.Array, road_index_A: chex.Ar Returns True if the segment is steep (requires jump to pass when going up). """ # Get the X difference for the current road segment - x_diff = jax.lax.cond( - current_road == 0, - lambda _: jnp.abs(self.consts.FIRST_TRACK_CORNERS_X[road_index_A + 1] - self.consts.FIRST_TRACK_CORNERS_X[road_index_A]), - lambda _: jnp.abs(self.consts.SECOND_TRACK_CORNERS_X[road_index_B + 1] - self.consts.SECOND_TRACK_CORNERS_X[road_index_B]), - operand=None, - ) + road_index = jnp.where(current_road == 0, road_index_A, road_index_B) + x_curr = jnp.where(current_road == 0, + self.consts.FIRST_TRACK_CORNERS_X[road_index], + self.consts.SECOND_TRACK_CORNERS_X[road_index]) + x_next = jnp.where(current_road == 0, + self.consts.FIRST_TRACK_CORNERS_X[road_index + 1], + self.consts.SECOND_TRACK_CORNERS_X[road_index + 1]) + x_diff = jnp.abs(x_next - x_curr) # A segment is steep if there's no X change (or very small change) return x_diff < 1.0 @@ -433,57 +505,26 @@ def _advance_player_car( slope, b = self._get_slope_and_intercept_from_indices(current_road, road_index_A, road_index_B) # Determine X direction based on current road segment (for normal movement) - direction_raw = jax.lax.cond( - current_road == 0, - lambda _: self.consts.FIRST_TRACK_CORNERS_X[road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_X[road_index_A], - lambda _: self.consts.SECOND_TRACK_CORNERS_X[road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[road_index_B], - operand=None, - ) - # Use sign, default to -1 for zero (vertical segments) - car_direction_x = jnp.where(direction_raw == 0, -1, jnp.sign(direction_raw)).astype(jnp.int32) + car_direction_x = self._compute_direction_x(current_road, road_index_A, road_index_B) position = EntityPosition(x=position_x, y=position_y, width=width, height=height) + # === CALCULATE ROAD-BASED MOVEMENT (used when not jumping) === + road_x, road_y = self._move_on_road( + position, slope, b, speed_sign, step_size, car_direction_x, move_y, move_x + ) + # === Y MOVEMENT === # When jumping: move freely in Y direction - # When on road: only move if allowed by road geometry - new_player_y = jax.lax.cond( - move_y, - lambda _: jax.lax.cond( - is_jumping, - lambda _: position_y + speed_sign * -step_size, # Free movement while jumping - lambda _: jax.lax.cond( - self._is_on_line_for_position(position, slope, b, speed_sign, 1), - lambda _: position_y + speed_sign * -step_size, - lambda _: jnp.array(position_y, float), - operand=None, - ), - operand=None, - ), - lambda _: jnp.array(position_y, float), - operand=None, - ) + # When on road: use road-based movement result + jump_y = jnp.where(move_y, position_y + speed_sign * -step_size, position_y) + new_player_y = jnp.where(is_jumping, jump_y, road_y) # === X MOVEMENT === # When jumping: use stored_jump_slope (locked at jump start) - moves X proportionally to Y - # The slope already encodes direction (dx/dy), so multiply by Y step size and speed_sign - # When on road: only move if allowed by road geometry - new_player_x = jax.lax.cond( - move_x, - lambda _: jax.lax.cond( - is_jumping, - lambda _: position_x - speed_sign * stored_jump_slope * step_size, # Slope-based movement (negated because Y decreases going forward) - lambda _: jax.lax.cond( - self._is_on_line_for_position(position, slope, b, speed_sign, 2), - lambda _: position_x + speed_sign * car_direction_x * step_size, # Normal road movement - lambda _: jnp.array(position_x, float), - operand=None, - ), - operand=None, - ), - lambda _: jnp.array(position_x, float), - operand=None, - ) + # When on road: use road-based movement result + jump_x = jnp.where(move_x, position_x - speed_sign * stored_jump_slope * step_size, position_x) + new_player_x = jnp.where(is_jumping, jump_x, road_x) # === LANDING LOGIC === # Get the current road segment based on new Y position @@ -524,60 +565,21 @@ def _advance_player_car( landing_in_water = jnp.logical_and(is_landing, jnp.logical_not(valid_landing)) # === UPDATE ROAD STATE === - # Determine which road to assign on landing - landed_road = jax.lax.cond( - on_road_A, - lambda _: jnp.int32(0), - lambda _: jax.lax.cond( - on_road_B, - lambda _: jnp.int32(1), - lambda _: nearest_road_id, # Between roads - use nearest - operand=None, - ), - operand=None, - ) + # Determine which road to assign on landing (priority: road A > road B > nearest) + landed_road = jnp.where(on_road_A, jnp.int32(0), jnp.where(on_road_B, jnp.int32(1), nearest_road_id)) - # Update current_road - # - If landing in water: set to 2 (water/crash marker) - # - If landing successfully: set to the landed road - # - If still jumping: keep current road (frozen during jump) - # - If on road normally: update based on position - updated_current_road = jax.lax.cond( - landing_in_water, - lambda _: jnp.int32(2), # Water crash - lambda _: jax.lax.cond( - is_landing, - lambda _: landed_road, # Successfully landed - lambda _: jax.lax.cond( - is_jumping, - lambda _: current_road, # Keep road frozen while jumping - lambda _: jax.lax.cond( - current_road == 2, - lambda _: nearest_road_id, # Recover from water state - lambda _: current_road, # Normal on-road movement - operand=None, - ), - operand=None, - ), - operand=None, - ), - operand=None, - ) + # Update current_road using nested jnp.where for vectorized execution + # Priority: water crash > landing > jumping (frozen) > recover from water > normal + normal_road = jnp.where(current_road == 2, nearest_road_id, current_road) + jumping_road = jnp.where(is_jumping, current_road, normal_road) + landing_road = jnp.where(is_landing, landed_road, jumping_road) + updated_current_road = jnp.where(landing_in_water, jnp.int32(2), landing_road) # Update road indices to match current segment when not jumping - next_road_index_A = jax.lax.cond( - jnp.logical_and(jnp.logical_not(is_jumping), updated_current_road == 0), - lambda _: segment, - lambda _: road_index_A, - operand=None, - ) - - next_road_index_B = jax.lax.cond( - jnp.logical_and(jnp.logical_not(is_jumping), updated_current_road == 1), - lambda _: segment, - lambda _: road_index_B, - operand=None, - ) + not_jumping_on_road_A = jnp.logical_and(jnp.logical_not(is_jumping), updated_current_road == 0) + not_jumping_on_road_B = jnp.logical_and(jnp.logical_not(is_jumping), updated_current_road == 1) + next_road_index_A = jnp.where(not_jumping_on_road_A, segment, road_index_A) + next_road_index_B = jnp.where(not_jumping_on_road_B, segment, road_index_B) # Wrap Y position for looping track wrapped_y = -((new_player_y * -1) % self.consts.TRACK_LENGTH) @@ -614,42 +616,14 @@ def _advance_car_core( """Simplified car advancement for enemy cars (no jumping/landing logic).""" # Calculate movement timing using helper move_y, move_x, step_size, speed_sign = self._compute_movement_timing(speed, step_counter) - slope, b = self._get_slope_and_intercept_from_indices(current_road, road_index_A, road_index_B) - - direction_raw = jax.lax.cond( - current_road == 0, - lambda _: self.consts.FIRST_TRACK_CORNERS_X[road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_X[road_index_A], - lambda _: self.consts.SECOND_TRACK_CORNERS_X[road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[road_index_B], - operand=None, - ) - # Use sign, default to -1 for zero (vertical segments) - car_direction_x = jnp.where(direction_raw == 0, -1, jnp.sign(direction_raw)).astype(jnp.int32) - + car_direction_x = self._compute_direction_x(current_road, road_index_A, road_index_B) + position = EntityPosition(x=position_x, y=position_y, width=width, height=height) - - new_y = jax.lax.cond( - move_y, - lambda _: jax.lax.cond( - self._is_on_line_for_position(position, slope, b, speed_sign, 1), - lambda _: position_y + speed_sign * -step_size, - lambda _: jnp.array(position_y, float), - operand=None, - ), - lambda _: jnp.array(position_y, float), - operand=None, - ) - - new_x = jax.lax.cond( - move_x, - lambda _: jax.lax.cond( - self._is_on_line_for_position(position, slope, b, speed_sign, 2), - lambda _: position_x + speed_sign * car_direction_x * step_size, - lambda _: jnp.array(position_x, float), - operand=None, - ), - lambda _: jnp.array(position_x, float), - operand=None, + + # Use shared movement helper + new_x, new_y = self._move_on_road( + position, slope, b, speed_sign, step_size, car_direction_x, move_y, move_x ) wrapped_y = -((new_y * -1) % self.consts.TRACK_LENGTH) @@ -707,10 +681,7 @@ def check_flag_collision(flag_idx): # Check if player is close enough to collect the flag y_distance = jnp.abs(new_player_y - flag_y) x_distance = jnp.abs(player_x - flag_x) - same_road = jnp.logical_or( - jnp.logical_and(current_road == 0, flag_road == 0), - jnp.logical_and(current_road == 1, flag_road == 1), - ) + same_road = (current_road == flag_road) collision = jnp.logical_and( jnp.logical_and(y_distance < self.consts.COLLISION_THRESHOLD, x_distance < self.consts.COLLISION_THRESHOLD), @@ -825,10 +796,7 @@ def check_collision(idx): y_distance = jnp.abs(new_player_y - c_y) x_distance = jnp.abs(player_x - c_x) - same_road = jnp.logical_or( - jnp.logical_and(current_road == 0, c_road == 0), - jnp.logical_and(current_road == 1, c_road == 1), - ) + same_road = (current_road == c_road) collision = jnp.logical_and( jnp.logical_and(y_distance < self.consts.COLLISION_THRESHOLD, x_distance < self.consts.COLLISION_THRESHOLD), @@ -1709,27 +1677,22 @@ def _createBackgroundSprite(self, dimensions: Tuple[int, int]) -> jnp.ndarray: self._num_road_tiles = int(self._road_tile_offsets.shape[0]) self.enemy_sprite_names = { - self.consts.ENEMY_TYPE_CAMERO: ("camero_left", "camero_right"), - self.consts.ENEMY_TYPE_FLAG_CARRIER: ("flag_carrier_left", "flag_carrier_right"), - self.consts.ENEMY_TYPE_PICKUP: ("pick_up_truck_left", "pick_up_truck_right"), - self.consts.ENEMY_TYPE_TRUCK: ("truck_left", "truck_right"), + self.consts.ENEMY_TYPE_CAMERO: "camero_left", + self.consts.ENEMY_TYPE_FLAG_CARRIER: "flag_carrier_left", + self.consts.ENEMY_TYPE_PICKUP: "pick_up_truck_left", + self.consts.ENEMY_TYPE_TRUCK: "truck_left", } # Pre-pad enemy masks to a common shape so switch/array indexing works under jit + # Only use left sprites - right sprites are created by flipping horizontally enemy_left_raw = [ self.SHAPE_MASKS["camero_left"], self.SHAPE_MASKS["flag_carrier_left"], self.SHAPE_MASKS["pick_up_truck_left"], self.SHAPE_MASKS["truck_left"], ] - enemy_right_raw = [ - self.SHAPE_MASKS["camero_right"], - self.SHAPE_MASKS["flag_carrier_right"], - self.SHAPE_MASKS["pick_up_truck_right"], - self.SHAPE_MASKS["truck_right"], - ] - max_h = max([m.shape[0] for m in enemy_left_raw + enemy_right_raw]) - max_w = max([m.shape[1] for m in enemy_left_raw + enemy_right_raw]) + max_h = max([m.shape[0] for m in enemy_left_raw]) + max_w = max([m.shape[1] for m in enemy_left_raw]) def _pad_mask(mask): pad_h = max_h - mask.shape[0] @@ -1737,7 +1700,8 @@ def _pad_mask(mask): return jnp.pad(mask, ((0, pad_h), (0, pad_w)), constant_values=self.jr.TRANSPARENT_ID) self.enemy_left_masks = jnp.stack([_pad_mask(m) for m in enemy_left_raw], axis=0) - self.enemy_right_masks = jnp.stack([_pad_mask(m) for m in enemy_right_raw], axis=0) + # Create right-facing masks by horizontally flipping the left masks + self.enemy_right_masks = jnp.flip(self.enemy_left_masks, axis=2) # Precompute flag mask data for recoloring without special-casing pink self.flag_base_mask = self.SHAPE_MASKS["pink_flag"] @@ -1786,6 +1750,15 @@ def _get_road_sprite_sizes(self, road_files: list[str]) -> list: complete_size = int(sum(sizes)) return sizes, complete_size + def _get_x_on_road(self, y: chex.Array, road_segment: chex.Array, track_corners_x: chex.Array) -> chex.Array: + """Calculate the X position on a road given a Y coordinate and road segment.""" + y1 = self.consts.TRACK_CORNERS_Y[road_segment] + y2 = self.consts.TRACK_CORNERS_Y[road_segment + 1] + x1 = track_corners_x[road_segment] + x2 = track_corners_x[road_segment + 1] + t = jnp.where(y2 != y1, (y - y1) / (y2 - y1), 0.0) + return x1 + t * (x2 - x1) + def _find_palette_id(self, rgba: jnp.ndarray) -> int: """Return palette index for an RGBA color, falling back to first entry if missing.""" color_rgb = rgba[:3] @@ -1808,20 +1781,6 @@ def _jump_arc_offset(self, jump_cooldown: chex.Array) -> chex.Array: centered = (progress - 0.5) * 2.0 return self.consts.JUMP_ARC_HEIGHT * (1.0 - centered * centered) - def _get_x_on_road(self, y: chex.Array, road_segment: chex.Array, track_corners_x: chex.Array) -> chex.Array: - """Linear interpolation of x along the given road segment for y.""" - y1 = self.consts.TRACK_CORNERS_Y[road_segment] - y2 = self.consts.TRACK_CORNERS_Y[road_segment + 1] - x1 = track_corners_x[road_segment] - x2 = track_corners_x[road_segment + 1] - t = jax.lax.cond( - y2 != y1, - lambda _: (y - y1) / (y2 - y1), - lambda _: 0.0, - operand=None, - ) - return x1 + t * (x2 - x1) - def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.ndarray, bottomBlockSprite: jnp.ndarray, tempPointer: jnp.ndarray, blackoutSquare: jnp.ndarray) -> tuple[list, list[str]]: """Returns the asset manifest and ordered road files.""" road_dir = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/roads" @@ -1834,14 +1793,11 @@ def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.n {'name': 'background', 'type': 'background', 'data': backgroundSprite}, {'name': 'road', 'type': 'group', 'files': roads}, {'name': 'player', 'type': 'single', 'file': 'player_car.npy'}, + # Only load left-facing enemy sprites; right-facing are created by flipping {'name': 'camero_left', 'type': 'single', 'file': 'enemy_cars/camero_left.npy'}, - {'name': 'camero_right', 'type': 'single', 'file': 'enemy_cars/camero_right.npy'}, {'name': 'flag_carrier_left', 'type': 'single', 'file': 'enemy_cars/flag_carrier_left.npy'}, - {'name': 'flag_carrier_right', 'type': 'single', 'file': 'enemy_cars/flag_carrier_right.npy'}, {'name': 'pick_up_truck_left', 'type': 'single', 'file': 'enemy_cars/pick_up_truck_left.npy'}, - {'name': 'pick_up_truck_right', 'type': 'single', 'file': 'enemy_cars/pick_up_truck_right.npy'}, {'name': 'truck_left', 'type': 'single', 'file': 'enemy_cars/truck_left.npy'}, - {'name': 'truck_right', 'type': 'single', 'file': 'enemy_cars/truck_right.npy'}, {'name': 'wall_top', 'type': 'procedural', 'data': topBlockSprite}, {'name': 'wall_bottom', 'type': 'procedural', 'data': bottomBlockSprite}, {'name': 'all_flags_top', 'type': 'single', 'file': 'all_flags_top.npy'}, @@ -1908,9 +1864,10 @@ def combine(i, acc): raster = jax.lax.fori_loop(0, total_segments, combine, raster) def select_enemy_mask(enemy_type: chex.Array, going_left: chex.Array): + """Select enemy mask: left masks are base, right masks are horizontally flipped.""" left_mask = self.enemy_left_masks[enemy_type] right_mask = self.enemy_right_masks[enemy_type] - return jax.lax.cond(going_left, lambda _: left_mask, lambda _: right_mask, operand=None) + return jnp.where(going_left, left_mask, right_mask) def render_enemy(carry, enemy_idx): raster = carry From af3ceabe0d69b33d5647aa079a6cd813d9605f1d Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sun, 21 Dec 2025 16:32:00 +0100 Subject: [PATCH 28/76] reworked behaivor on step sections --- src/jaxatari/games/jax_upndown.py | 118 +++++++++++++++++++++--------- 1 file changed, 84 insertions(+), 34 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index f1b717075..005b8e0bf 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -430,6 +430,31 @@ def _is_steep_road_segment(self, current_road: chex.Array, road_index_A: chex.Ar # A segment is steep if there's no X change (or very small change) return x_diff < 1.0 + @partial(jax.jit, static_argnums=(0,)) + def _get_steep_segment_progress(self, position_y: chex.Array, current_road: chex.Array, + road_index_A: chex.Array, road_index_B: chex.Array) -> chex.Array: + """Calculate progress (0.0 to 1.0) through the current steep road segment. + + 0.0 = at the bottom (start) of the steep segment + 1.0 = at the top (end) of the steep segment + + Progress is measured in the direction of forward travel (upward = positive Y direction in game space, + but Y decreases as we go forward on the track). + """ + road_index = jnp.where(current_road == 0, road_index_A, road_index_B) + # Y coordinates of segment boundaries + y_start = self.consts.TRACK_CORNERS_Y[road_index] # Start of segment (lower Y = further ahead) + y_end = self.consts.TRACK_CORNERS_Y[road_index + 1] # End of segment (higher Y in absolute terms) + + # Calculate progress: how far through the segment are we? + # Since Y decreases as we go forward, we need to invert + segment_length = jnp.abs(y_end - y_start) + # Distance from segment start (in forward direction) + distance_from_start = jnp.abs(position_y - y_start) + + progress = jnp.where(segment_length > 0.001, distance_from_start / segment_length, 0.0) + return jnp.clip(progress, 0.0, 1.0) + @partial(jax.jit, static_argnums=(0,)) def _check_landing_position( self, @@ -871,65 +896,90 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: up = jnp.logical_or(action == Action.UP, action == Action.UPFIRE) down = jnp.logical_or(action == Action.DOWN, action == Action.DOWNFIRE) jump = jnp.logical_or(action == Action.FIRE, jnp.logical_or(action == Action.UPFIRE, action == Action.DOWNFIRE)) - player_speed = state.player_car.speed.astype(jnp.float32) - - # Use jnp.where for branchless execution + + # Check if on a steep road section FIRST (before applying speed changes) + is_on_steep_road = self._is_steep_road_segment( + state.player_car.current_road, + state.player_car.road_index_A, + state.player_car.road_index_B, + ) + + # Calculate progress through steep segment (0.0 = bottom, 1.0 = top) + steep_progress = self._get_steep_segment_progress( + state.player_car.position.y, + state.player_car.current_road, + state.player_car.road_index_A, + state.player_car.road_index_B, + ) + + # Determine if player is on steep road going up (not jumping) + on_steep_not_jumping = jnp.logical_and(is_on_steep_road, jnp.logical_not(state.is_jumping)) + + # Start with current speed + player_speed = state.player_car.speed + + # === STEEP ROAD BLOCKING LOGIC === + # On steep road: UP action has NO effect (can't accelerate while on steep section) + # Apply UP acceleration only if NOT on steep road (or if jumping over it) + can_accelerate = jnp.logical_not(on_steep_not_jumping) player_speed = jnp.where( - jnp.logical_and(state.player_car.speed < self.consts.MAX_SPEED, up), + jnp.logical_and(jnp.logical_and(player_speed < self.consts.MAX_SPEED, up), can_accelerate), player_speed + 1, player_speed, ) - + + # DOWN action always works (can brake/reverse) player_speed = jnp.where( - jnp.logical_and(state.player_car.speed > -self.consts.MAX_SPEED, down), + jnp.logical_and(player_speed > -self.consts.MAX_SPEED, down), player_speed - 1, player_speed, ) - - # Check if on a steep road section (no X direction change) and apply speed reduction - # This simulates steep road sections that require a jump to pass when going upward - is_on_steep_road = self._is_steep_road_segment( - state.player_car.current_road, - state.player_car.road_index_A, - state.player_car.road_index_B, - ) - # Only apply steep road penalty when: - # 1. Player is on a steep road section - # 2. Player is not jumping - # 3. Player has positive speed (going upward) - player_speed, steep_road_timer, jump_boost_multiplier = self._apply_steep_road_penalty( - player_speed, is_on_steep_road, state.steep_road_timer, state.is_jumping, state.jump_cooldown - ) - on_steep_going_up = jnp.logical_and( - is_on_steep_road, - jnp.logical_and( - jnp.logical_not(state.is_jumping), - player_speed > 0 - ) - ) - # Update steep road timer - increment when on steep road going up, reset otherwise (use jnp.where) + + # === STEEP ROAD SPEED REDUCTION & SLIDE BACK === + # Only apply when on steep road, not jumping, and trying to go up (positive speed) + on_steep_going_up = jnp.logical_and(on_steep_not_jumping, player_speed > 0) + + # Update steep road timer - increment when on steep road going up steep_road_timer = jnp.where( on_steep_going_up, state.steep_road_timer + 1, jnp.array(0, dtype=jnp.int32), ) - # Only reduce speed when timer reaches the interval threshold + + # Check if player has reached halfway point (50% progress through segment) + past_halfway = steep_progress >= 0.5 + + # Two behaviors based on progress: + # 1. Before halfway: gradually reduce speed using timer + # 2. At/past halfway: immediately set speed to -2 (slide back) + + # Before halfway: reduce speed periodically using timer should_reduce_speed = jnp.logical_and( on_steep_going_up, - steep_road_timer >= self.consts.STEEP_ROAD_SPEED_REDUCTION_INTERVAL + jnp.logical_and( + jnp.logical_not(past_halfway), + steep_road_timer >= self.consts.STEEP_ROAD_SPEED_REDUCTION_INTERVAL + ) ) - # Gradually reduce speed toward -2 when on steep section without jumping (use jnp.where) player_speed = jnp.where( should_reduce_speed, - jnp.maximum(player_speed - 1, jnp.int32(-2)), + jnp.maximum(player_speed - 1, jnp.int32(0)), # Reduce but not below 0 yet player_speed, ) - # Reset timer after speed reduction (use jnp.where) + # Reset timer after speed reduction steep_road_timer = jnp.where( should_reduce_speed, jnp.array(0, dtype=jnp.int32), steep_road_timer, ) + + # At/past halfway: force speed to -2 (slide back down) + should_slide_back = jnp.logical_and(on_steep_going_up, past_halfway) + player_speed = jnp.where( + should_slide_back, + jnp.int32(-3), + player_speed, + ) can_start_jump = jnp.logical_and(state.jump_cooldown == 0, state.post_jump_cooldown == 0) is_jumping = jnp.logical_or( From 1a0e4a8c9afe214c4f63ba200f73c2098458aae0 Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sun, 21 Dec 2025 17:13:18 +0100 Subject: [PATCH 29/76] move RNG key up and make functions jittable --- src/jaxatari/games/jax_upndown.py | 63 ++++++++++++++++--------------- 1 file changed, 32 insertions(+), 31 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 005b8e0bf..a99f23f4e 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -165,6 +165,7 @@ class UpNDownState(NamedTuple): is_dead: chex.Array respawn_timer: chex.Array step_counter: chex.Array + rng_key: chex.PRNGKey round_started: chex.Array movement_steps: chex.Array steep_road_timer: chex.Array # Timer for steep road speed reduction @@ -187,8 +188,6 @@ class UpNDownObservation(NamedTuple): class UpNDownInfo(NamedTuple): time: jnp.ndarray - - class JaxUpNDown(JaxEnvironment[UpNDownState, UpNDownObservation, UpNDownInfo, UpNDownConstants]): def __init__(self, consts: UpNDownConstants = None, reward_funcs: list[callable]=None): consts = consts or UpNDownConstants() @@ -734,9 +733,9 @@ def check_flag_collision(flag_idx): return new_flags, flag_score, new_flags_collected_mask @partial(jax.jit, static_argnums=(0,)) - def _collectible_step(self, state: UpNDownState, new_player_y: chex.Array, player_x: chex.Array, current_road: chex.Array) -> Tuple[Collectible, chex.Array, chex.Array]: + def _collectible_step(self, state: UpNDownState, new_player_y: chex.Array, player_x: chex.Array, current_road: chex.Array, rng_key: chex.PRNGKey) -> Tuple[Collectible, chex.Array, chex.Array, chex.PRNGKey]: """Update collectible spawning, despawning, and collection (unified for all types). - + Handles mixed-type collectibles (cherry, balloon, lollypop, ice cream) in a single pool. Type is randomized on spawn with probabilities defined in COLLECTIBLE_SPAWN_PROBABILITIES. @@ -745,10 +744,13 @@ def _collectible_step(self, state: UpNDownState, new_player_y: chex.Array, playe new_player_y: Updated player Y position after movement player_x: Current player X position current_road: Current road player is on + rng_key: PRNG key to drive spawn randomness Returns: - Tuple of (updated_collectibles, score_delta, new_spawn_timer) + Tuple of (updated_collectibles, score_delta, new_spawn_timer, new_rng_key) """ + rng_key, key1, key2, key3, key4 = jax.random.split(rng_key, 5) + # Collectible spawning logic - decrement timer and spawn when ready (use jnp.where for branchless) new_collectible_timer = jnp.where( state.collectible_spawn_timer <= 0, @@ -764,10 +766,6 @@ def _collectible_step(self, state: UpNDownState, new_player_y: chex.Array, playe has_inactive_slot = jnp.any(inactive_mask) spawn_idx = jnp.where(has_inactive_slot, first_inactive, jnp.array(0, dtype=jnp.int32)) - # Generate random spawn position using fold_in for deterministic randomness - base_key = jax.random.PRNGKey(0) - key_for_spawn = jax.random.fold_in(base_key, state.step_counter) - key1, key2, key3, key4 = jax.random.split(key_for_spawn, 4) y_spawn = jax.random.uniform(key1, minval=-900.0, maxval=-100.0) road_spawn = jnp.array(jax.random.randint(key2, shape=(), minval=0, maxval=2), dtype=jnp.int32) color_spawn = jnp.array(jax.random.randint(key3, shape=(), minval=0, maxval=len(self.consts.COLLECTIBLE_COLORS)), dtype=jnp.int32) @@ -848,7 +846,7 @@ def check_collision(idx): active=final_active, ) - return updated_collectibles, score_delta, new_collectible_timer + return updated_collectibles, score_delta, new_collectible_timer, rng_key @partial(jax.jit, static_argnums=(0,)) def _death_step(self, state: UpNDownState) -> UpNDownState: @@ -1125,14 +1123,15 @@ def _collectible_step_main(self, state: UpNDownState) -> UpNDownState: player_x = state.player_car.position.x current_road = state.player_car.current_road - updated_collectibles, collectible_score, new_collectible_timer = self._collectible_step( - state, new_player_y, player_x, current_road + updated_collectibles, collectible_score, new_collectible_timer, rng_key = self._collectible_step( + state, new_player_y, player_x, current_road, state.rng_key ) return state._replace( score=state.score + collectible_score, collectibles=updated_collectibles, collectible_spawn_timer=new_collectible_timer, + rng_key=rng_key, ) @partial(jax.jit, static_argnums=(0,)) @@ -1205,9 +1204,7 @@ def init_direction(seg, road): @partial(jax.jit, static_argnums=(0,)) def _enemy_step_main(self, state: UpNDownState) -> UpNDownState: """Spawn and move enemy cars with adaptive spawning for consistent enemy presence.""" - base_key = jax.random.PRNGKey(2025) - step_key = jax.random.fold_in(base_key, state.step_counter) - key_spawn_offset, key_spawn_side, key_spawn_speed, key_spawn_direction, key_spawn_type, key_spawn_sign, key_flip_root = jax.random.split(step_key, 7) + rng_key, key_spawn_offset, key_spawn_side, key_spawn_speed, key_spawn_direction, key_spawn_type, key_spawn_sign, key_flip_root = jax.random.split(state.rng_key, 8) active_mask = state.enemy_cars.active active_count = jnp.sum(active_mask.astype(jnp.int32)) @@ -1361,14 +1358,13 @@ def _enemy_step_main(self, state: UpNDownState) -> UpNDownState: return state._replace( enemy_cars=next_enemy_cars, enemy_spawn_timer=spawn_timer, + rng_key=rng_key, ) @partial(jax.jit, static_argnums=(0,)) def _respawn_after_collision(self, state: UpNDownState, new_lives: chex.Array) -> UpNDownState: """Respawn the player on a random road while preserving score and flags.""" - base_key = jax.random.PRNGKey(1337) - key_spawn = jax.random.fold_in(base_key, state.step_counter) - road_key, enemy_key = jax.random.split(key_spawn, 2) + rng_key, road_key, enemy_key = jax.random.split(state.rng_key, 3) player_start_y = jnp.array(0.0) start_segment = jnp.array(0, dtype=jnp.int32) @@ -1421,6 +1417,7 @@ def _respawn_after_collision(self, state: UpNDownState, new_lives: chex.Array) - collectible_spawn_timer=jnp.array(0, dtype=jnp.int32), enemy_cars=enemy_cars, enemy_spawn_timer=jnp.array(self.consts.ENEMY_SPAWN_INTERVAL_BASE, dtype=jnp.int32), + rng_key=rng_key, ) @partial(jax.jit, static_argnums=(0,)) @@ -1515,15 +1512,11 @@ def _passive_score_step_main(self, state: UpNDownState) -> UpNDownState: bonus = jnp.where(should_award, jnp.int32(self.consts.PASSIVE_SCORE_AMOUNT), jnp.int32(0)) return state._replace(score=state.score + bonus) + + @partial(jax.jit, static_argnums=(0,)) + def _reset_jit(self, key: chex.PRNGKey) -> Tuple[UpNDownObservation, UpNDownState]: + rng_key, flag_key, enemy_key = jax.random.split(key, 3) - - def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: - # Initialize flags at random positions along the track - # Use key for randomness if provided, otherwise use default positions - if key is None: - key = jax.random.PRNGKey(42) - - key, flag_key, enemy_key = jax.random.split(key, 3) # Evenly spread flags along the track with small jitter base_y = jnp.linspace(-900.0, -100.0, self.consts.NUM_FLAGS) jitter = jax.random.uniform(flag_key, shape=(self.consts.NUM_FLAGS,), minval=-40.0, maxval=40.0) @@ -1531,13 +1524,13 @@ def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: # Alternate roads 0/1 for variety flag_roads = jnp.arange(self.consts.NUM_FLAGS) % 2 - + # Calculate which road segment each flag is on based on Y position flag_segments = jax.vmap(self._get_road_segment)(flag_y_offsets) - + # Each flag color index corresponds to its position (0-7) flag_color_indices = jnp.arange(self.consts.NUM_FLAGS) - + flags = Flag( y=flag_y_offsets, road=flag_roads, @@ -1545,14 +1538,14 @@ def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: color_idx=flag_color_indices, collected=jnp.zeros(self.consts.NUM_FLAGS, dtype=jnp.bool_), ) - + # Initialize collectibles as all inactive (will spawn dynamically with mixed types) collectibles = self._initialize_collectibles() # Seed initial visible enemies spaced around the player player_start_y = jnp.array(0.0) enemy_cars = self._initialize_enemies(enemy_key, player_start_y) - + state = UpNDownState( score=0, lives=jnp.array(self.consts.INITIAL_LIVES, dtype=jnp.int32), @@ -1578,6 +1571,7 @@ def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: type=0, ), step_counter=jnp.array(0), + rng_key=rng_key, round_started=jnp.array(False), movement_steps=jnp.array(0), steep_road_timer=jnp.array(0, dtype=jnp.int32), @@ -1591,6 +1585,12 @@ def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: ) initial_obs = self._get_observation(state) return initial_obs, state + + @partial(jax.jit, static_argnums=(0,)) + def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: + if key is None: + key = jax.random.PRNGKey(42) + return self._reset_jit(key) @partial(jax.jit, static_argnums=(0,)) def step(self, state: UpNDownState, action: chex.Array) -> Tuple[UpNDownObservation, UpNDownState, float, bool, UpNDownInfo]: @@ -1617,6 +1617,7 @@ def render(self, state: UpNDownState) -> jnp.ndarray: frame = self.renderer.render(state) return jnp.asarray(frame, dtype=jnp.uint8) + @partial(jax.jit, static_argnums=(0,)) def _get_observation(self, state: UpNDownState): # Clamp to screen-friendly coordinates so observation_space.contains passes x = jnp.int32(jnp.clip(state.player_car.position.x, 0, 160)) From e03013fd00a1c37f26cdeb260849a083bd3d1a30 Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sun, 21 Dec 2025 18:34:10 +0100 Subject: [PATCH 30/76] add confirmation logic to start a new round --- src/jaxatari/games/jax_upndown.py | 503 +++++++++++++++++++++++------- 1 file changed, 396 insertions(+), 107 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index a99f23f4e..ca04cba7c 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -28,14 +28,14 @@ class UpNDownConstants(NamedTuple): ENEMY_SPAWN_INTERVAL_MAX: int = 60 # Max spawn interval when many enemies exist ENEMY_MIN_VISIBLE_COUNT: int = 2 # Minimum enemies to keep on screen ENEMY_VISIBLE_DISTANCE: int = 120 # Distance within which enemies are considered "visible" - ENEMY_DESPAWN_DISTANCE: int = 250 + ENEMY_DESPAWN_DISTANCE: int = 250 ENEMY_SPEED_MIN: int = 3 ENEMY_SPEED_MAX: int = 5 ENEMY_DIRECTION_SWITCH_PROB: float = 0.0001 - ENEMY_SPAWN_OFFSET_MIN: float = 70.0 # Closer spawn distance + ENEMY_SPAWN_OFFSET_MIN: float = 70.0 # Closer spawn distance ENEMY_SPAWN_OFFSET_MAX: float = 130.0 # Max spawn offset ENEMY_MIN_SPAWN_GAP: float = 25.0 # Reduced gap between spawns - ENEMY_MAX_AGE: int = 1900 + ENEMY_MAX_AGE: int = 1900 INITIAL_ENEMY_COUNT: int = 4 INITIAL_ENEMY_BASE_OFFSET: float = 35.0 # Closer initial enemies INITIAL_ENEMY_GAP: float = 25.0 # Tighter initial spacing @@ -59,9 +59,9 @@ class UpNDownConstants(NamedTuple): PASSIVE_SCORE_AMOUNT: int = 10 # Points awarded for passive scoring COLLISION_THRESHOLD: float = 5.0 # Distance threshold for flag/collectible collision TRACK_LENGTH: int = 1036 - FIRST_TRACK_CORNERS_X: chex.Array = jnp.array([30, 75, 128, 75, 21, 75, 131, 111, 150, 95, 150, 115, 150, 108, 150, 115, 115, 75, 18, 38, 67, 38, 38, 20, 64, 30]) + FIRST_TRACK_CORNERS_X: chex.Array = jnp.array([30, 75, 128, 75, 21, 75, 131, 111, 150, 95, 150, 115, 150, 108, 150, 115, 115, 75, 18, 38, 67, 38, 38, 20, 64, 30]) TRACK_CORNERS_Y: chex.Array = jnp.array([0, -40, -98, -155, -203, -268, -327, -347, -382, -467, -525, -565, -597, -625, -670, -705, -738, -788, -838, -862, -898, -925, -950, -972, -1000, -1035]) - SECOND_TRACK_CORNERS_X: chex.Array = jnp.array([115, 75, 20, 75, 133, 75, 22, 37, 63, 27, 66, 30, 63, 24, 60, 38, 38, 75, 131, 111, 150, 118, 118, 98, 150, 115]) + SECOND_TRACK_CORNERS_X: chex.Array = jnp.array([115, 75, 20, 75, 133, 75, 22, 37, 63, 27, 66, 30, 63, 24, 60, 38, 38, 75, 131, 111, 150, 118, 118, 98, 150, 115]) PLAYER_SIZE: Tuple[int, int] = (4, 16) INITIAL_ROAD_POS_Y: int = 25 # Flag constants - 8 flags with different colors matching the top row @@ -69,7 +69,7 @@ class UpNDownConstants(NamedTuple): # Flag colors as RGBA values (matching the top row from left to right) FLAG_COLORS: chex.Array = jnp.array([ [184, 50, 50, 255], # Red - [181, 83, 40, 255], # Orange + [181, 83, 40, 255], # Orange [162, 98, 33, 255], # Dark orange [134, 134, 29, 255], # Yellow/olive [200, 72, 72, 255], # Pink (original) @@ -179,15 +179,46 @@ class UpNDownState(NamedTuple): # Enemy cars - dynamic spawning and movement enemy_cars: EnemyCars enemy_spawn_timer: chex.Array + # Death/respawn state - player is dead and waiting for input to respawn + awaiting_respawn: chex.Array # True when player died and is waiting for input + # Round start state - everything frozen and hidden until player presses input + awaiting_round_start: chex.Array # True at game start and after respawn until input received + # Input debounce - requires button release before next input triggers round start + input_released: chex.Array # True when player has released all buttons since last state change class UpNDownObservation(NamedTuple): - player: EntityPosition + """Complete observation for RL agents in Up N Down. + + Reuses existing game classes for consistency: + - player_car: Car with EntityPosition, speed, type, road info + - enemy_cars: EnemyCars pool with positions, speeds, types, active flags + - flags: Flag with y, road, segment, color, collected status + - collectibles: Collectible with positions, types, active status + - Additional game state: score, lives, jumping status, etc. + """ + player_car: Car # Reuse existing Car class + enemy_cars: EnemyCars # Reuse existing EnemyCars class + flags: Flag # Reuse existing Flag class + collectibles: Collectible # Reuse existing Collectible class + flags_collected_mask: jnp.ndarray # Shape (NUM_FLAGS,) - boolean mask + player_score: jnp.ndarray + lives: jnp.ndarray + is_jumping: jnp.ndarray # Whether player is currently jumping + jump_cooldown: jnp.ndarray # Frames remaining in jump + is_on_steep_road: jnp.ndarray # Whether currently on steep section + round_started: jnp.ndarray # Whether player has started moving + class UpNDownInfo(NamedTuple): - time: jnp.ndarray + """Additional info for debugging and analysis.""" + step_counter: jnp.ndarray # Total steps taken + difficulty: jnp.ndarray # Current difficulty level + movement_steps: jnp.ndarray # Steps since round started + jump_slope: jnp.ndarray # Current jump trajectory slope + player_road_segment: jnp.ndarray # Current road segment index class JaxUpNDown(JaxEnvironment[UpNDownState, UpNDownObservation, UpNDownInfo, UpNDownConstants]): def __init__(self, consts: UpNDownConstants = None, reward_funcs: list[callable]=None): consts = consts or UpNDownConstants() @@ -204,7 +235,22 @@ def __init__(self, consts: UpNDownConstants = None, reward_funcs: list[callable] Action.DOWN, Action.DOWNFIRE, ] - self.obs_size = 3*4+1+1 + # Calculate obs_size based on observation structure: + # Player car: 8 values (x, y, w, h, speed, type, road, direction_x) + # Enemy cars: MAX_ENEMY_CARS * 8 = 8 * 8 = 64 (x, y, w, h, speed, type, road, active per car) + # Flags: NUM_FLAGS * 5 = 8 * 5 = 40 (y, road, segment, color, collected per flag) + # Collectibles: MAX_COLLECTIBLES * 5 = 1 * 5 = 5 (y, x, road, type, active per collectible) + # Flags collected mask: NUM_FLAGS = 8 + # Score, lives, is_jumping, jump_cooldown, is_on_steep_road, round_started: 6 + # Total: 8 + 64 + 40 + 5 + 8 + 6 = 131 + self.obs_size = ( + 8 + # player car + self.consts.MAX_ENEMY_CARS * 8 + # enemy cars + self.consts.NUM_FLAGS * 5 + # flags + self.consts.MAX_COLLECTIBLES * 5 + # collectibles + self.consts.NUM_FLAGS + # flags_collected_mask + 6 # score, lives, is_jumping, jump_cooldown, is_on_steep_road, round_started + ) # Speed dividers for movement timing (indexed by speed level) self._speed_dividers = jnp.array([0, 1, 2, 4, 8, 16, 16]) @@ -850,41 +896,41 @@ def check_collision(idx): @partial(jax.jit, static_argnums=(0,)) def _death_step(self, state: UpNDownState) -> UpNDownState: - """Handle player death when on water road (index 2).""" - # Player on water road (index 2 assumed water) + """Handle player death - this is now only used for water crashes during landing. + + When the player dies: + - Lives are decremented + - is_dead is set to True + - awaiting_respawn is set to True + - Player car is moved off-screen (despawned) + - Game waits for player input before respawning + """ + # Skip if already awaiting respawn + already_awaiting = state.awaiting_respawn + + # Player on water road (index 2 assumed water) and not already dead died = jnp.logical_and( - state.player_car.current_road == 2, - ~state.is_dead, + jnp.logical_and( + state.player_car.current_road == 2, + ~state.is_dead, + ), + ~already_awaiting, ) # Use jnp.where for branchless execution lives = jnp.where(died, state.lives - 1, state.lives) - respawn_timer = jnp.where( - died, - jnp.array(self.consts.RESPAWN_DELAY_FRAMES), - jnp.maximum(state.respawn_timer - 1, 0), - ) - is_dead = jnp.logical_and( - jnp.logical_or(state.is_dead, died), - respawn_timer > 0, - ) - - # Respawn player when dead and timer expires - should_respawn = jnp.logical_and(state.is_dead, respawn_timer == 0) - new_position = state.player_car.position._replace( - x=jnp.where(should_respawn, jnp.array(self.consts.RESPAWN_X, dtype=jnp.float32), state.player_car.position.x), - y=jnp.where(should_respawn, jnp.array(self.consts.RESPAWN_Y, dtype=jnp.float32), state.player_car.position.y), - ) + is_dead = jnp.logical_or(state.is_dead, died) + awaiting_respawn = jnp.logical_or(state.awaiting_respawn, died) + + # Stop player movement but keep position (renderer will hide player when awaiting_respawn) player_car = state.player_car._replace( - position=new_position, - speed=jnp.where(should_respawn, 0, state.player_car.speed), - current_road=jnp.where(should_respawn, 0, state.player_car.current_road), + speed=jnp.where(died, 0, state.player_car.speed), ) return state._replace( lives=lives, is_dead=is_dead, - respawn_timer=respawn_timer, + awaiting_respawn=awaiting_respawn, player_car=player_car, ) @@ -1084,9 +1130,22 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: water_crash = jnp.logical_and(is_landing, updated_player_car.current_road == 2) + # On water crash, trigger death state instead of immediate respawn + def trigger_death(s): + # Stop player but keep position (renderer will hide player when awaiting_respawn) + dead_car = s.player_car._replace( + speed=jnp.array(0, dtype=jnp.int32), + ) + return s._replace( + lives=s.lives - 1, + is_dead=jnp.array(True), + awaiting_respawn=jnp.array(True), + player_car=dead_car, + ) + return jax.lax.cond( water_crash, - lambda _: self._respawn_after_collision(next_state, next_state.lives - 1), + lambda _: trigger_death(next_state), lambda _: next_state, operand=None, ) @@ -1204,7 +1263,11 @@ def init_direction(seg, road): @partial(jax.jit, static_argnums=(0,)) def _enemy_step_main(self, state: UpNDownState) -> UpNDownState: """Spawn and move enemy cars with adaptive spawning for consistent enemy presence.""" - rng_key, key_spawn_offset, key_spawn_side, key_spawn_speed, key_spawn_direction, key_spawn_type, key_spawn_sign, key_flip_root = jax.random.split(state.rng_key, 8) + # Split RNG keys - use more splits to ensure better randomization + rng_key, key_spawn_offset, key_spawn_side, key_spawn_speed, key_spawn_direction, key_spawn_type, key_spawn_sign, key_flip_root, key_extra = jax.random.split(state.rng_key, 9) + + # Further split key_spawn_type to get more entropy for type selection + key_spawn_type = jax.random.fold_in(key_spawn_type, state.step_counter) active_mask = state.enemy_cars.active active_count = jnp.sum(active_mask.astype(jnp.int32)) @@ -1417,6 +1480,9 @@ def _respawn_after_collision(self, state: UpNDownState, new_lives: chex.Array) - collectible_spawn_timer=jnp.array(0, dtype=jnp.int32), enemy_cars=enemy_cars, enemy_spawn_timer=jnp.array(self.consts.ENEMY_SPAWN_INTERVAL_BASE, dtype=jnp.int32), + awaiting_respawn=jnp.array(False), + awaiting_round_start=jnp.array(True), # Wait for input to start round after respawn + input_released=jnp.array(False), # Require button release before round can start rng_key=rng_key, ) @@ -1485,9 +1551,18 @@ def handle_late_jump(): return state._replace(score=state.score + bonus, enemy_cars=new_enemy_cars) def handle_ground_collision(): - return self._respawn_after_collision(state, state.lives - 1) + # Trigger death state - stop player but keep position (renderer hides player when awaiting_respawn) + dead_car = state.player_car._replace( + speed=jnp.array(0, dtype=jnp.int32), + ) + return state._replace( + lives=state.lives - 1, + is_dead=jnp.array(True), + awaiting_respawn=jnp.array(True), + player_car=dead_car, + ) - # Ground collision causes respawn (landing is now protected by invincibility) + # Ground collision causes death (landing is now protected by invincibility) any_fatal_collision = grounded_collision return jax.lax.cond( @@ -1582,6 +1657,9 @@ def _reset_jit(self, key: chex.PRNGKey) -> Tuple[UpNDownObservation, UpNDownStat collectible_spawn_timer=jnp.array(0, dtype=jnp.int32), enemy_cars=enemy_cars, enemy_spawn_timer=jnp.array(self.consts.ENEMY_SPAWN_INTERVAL_BASE, dtype=jnp.int32), + awaiting_respawn=jnp.array(False), + awaiting_round_start=jnp.array(True), # Start frozen until first input + input_released=jnp.array(True), # Can start immediately at game start ) initial_obs = self._get_observation(state) return initial_obs, state @@ -1595,15 +1673,62 @@ def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: @partial(jax.jit, static_argnums=(0,)) def step(self, state: UpNDownState, action: chex.Array) -> Tuple[UpNDownObservation, UpNDownState, float, bool, UpNDownInfo]: previous_state = state - state = self._player_step(state, action) - state = self._death_step(state) - - state = self._passive_score_step_main(state) - state = self._flag_step_main(state) - state = self._completion_bonus_step(state) - state = self._collectible_step_main(state) - state = self._enemy_step_main(state) - state = self._enemy_collision_step_main(state) + + any_action = action != Action.NOOP + + # Track input release - set to True when no button is pressed + input_released = jnp.where(any_action, state.input_released, jnp.array(True)) + state = state._replace(input_released=input_released) + + # Check if we're awaiting respawn - if so, check for input to trigger respawn + should_respawn = jnp.logical_and(state.awaiting_respawn, any_action) + + # Respawn if player pressed any key while awaiting + state = jax.lax.cond( + should_respawn, + lambda s: self._respawn_after_collision(s, s.lives), # lives already decremented + lambda s: s, + state, + ) + + # Check if we're awaiting round start - if so, check for input to start round + # Only start if input was released since respawn (prevents holding button through) + should_start_round = jnp.logical_and( + jnp.logical_and(state.awaiting_round_start, any_action), + state.input_released # Must have released button first + ) + state = jax.lax.cond( + should_start_round, + lambda s: s._replace(awaiting_round_start=jnp.array(False)), + lambda s: s, + state, + ) + + # Skip all game logic if awaiting respawn OR awaiting round start + is_frozen = jnp.logical_or(state.awaiting_respawn, state.awaiting_round_start) + + def run_game_logic(s): + s = self._player_step(s, action) + s = self._death_step(s) + s = self._passive_score_step_main(s) + s = self._flag_step_main(s) + s = self._completion_bonus_step(s) + s = self._collectible_step_main(s) + s = self._enemy_step_main(s) + s = self._enemy_collision_step_main(s) + return s + + def freeze_game(s): + # Only increment step counter while frozen, everything else paused + return s._replace(step_counter=s.step_counter + 1) + + # Run game logic only if not frozen + state = jax.lax.cond( + is_frozen, + freeze_game, + run_game_logic, + state, + ) done = self._get_done(state) env_reward = self._get_reward(previous_state, state) @@ -1618,39 +1743,175 @@ def render(self, state: UpNDownState) -> jnp.ndarray: return jnp.asarray(frame, dtype=jnp.uint8) @partial(jax.jit, static_argnums=(0,)) - def _get_observation(self, state: UpNDownState): - # Clamp to screen-friendly coordinates so observation_space.contains passes - x = jnp.int32(jnp.clip(state.player_car.position.x, 0, 160)) - screen_y = jnp.int32(105) - - player = EntityPosition( - x=x, - y=screen_y, - width=jnp.int32(self.consts.PLAYER_SIZE[0]), - height=jnp.int32(self.consts.PLAYER_SIZE[1]), + def _get_observation(self, state: UpNDownState) -> UpNDownObservation: + """Build complete observation for RL agents. + + Reuses existing game classes directly from state for consistency. + """ + # Check if on steep road + is_on_steep_road = self._is_steep_road_segment( + state.player_car.current_road, + state.player_car.road_index_A, + state.player_car.road_index_B, ) - return UpNDownObservation(player=player) + + return UpNDownObservation( + player_car=state.player_car, + enemy_cars=state.enemy_cars, + flags=state.flags, + collectibles=state.collectibles, + flags_collected_mask=state.flags_collected_mask, + player_score=jnp.int32(state.score), + lives=jnp.int32(state.lives), + is_jumping=jnp.int32(state.is_jumping), + jump_cooldown=jnp.int32(state.jump_cooldown), + is_on_steep_road=jnp.int32(is_on_steep_road), + round_started=jnp.int32(state.round_started), + ) + + @partial(jax.jit, static_argnums=(0,)) + def flatten_car(self, car: Car) -> jnp.ndarray: + """Flatten a Car to a 1D array.""" + return jnp.concatenate([ + jnp.array([car.position.x], dtype=jnp.float32), + jnp.array([car.position.y], dtype=jnp.float32), + jnp.array([car.position.width], dtype=jnp.float32), + jnp.array([car.position.height], dtype=jnp.float32), + jnp.array([car.speed], dtype=jnp.float32), + jnp.array([car.type], dtype=jnp.float32), + jnp.array([car.current_road], dtype=jnp.float32), + jnp.array([car.direction_x], dtype=jnp.float32), + ]) + + @partial(jax.jit, static_argnums=(0,)) + def flatten_enemy_cars(self, enemy_cars: EnemyCars) -> jnp.ndarray: + """Flatten EnemyCars to a 1D array.""" + return jnp.concatenate([ + enemy_cars.position.x, + enemy_cars.position.y, + enemy_cars.position.width, + enemy_cars.position.height, + enemy_cars.speed.astype(jnp.float32), + enemy_cars.type.astype(jnp.float32), + enemy_cars.current_road.astype(jnp.float32), + enemy_cars.active.astype(jnp.float32), + ]) + + @partial(jax.jit, static_argnums=(0,)) + def flatten_flags(self, flags: Flag) -> jnp.ndarray: + """Flatten Flag to a 1D array.""" + return jnp.concatenate([ + flags.y, + flags.road.astype(jnp.float32), + flags.road_segment.astype(jnp.float32), + flags.color_idx.astype(jnp.float32), + flags.collected.astype(jnp.float32), + ]) + + @partial(jax.jit, static_argnums=(0,)) + def flatten_collectibles(self, collectibles: Collectible) -> jnp.ndarray: + """Flatten Collectible to a 1D array.""" + return jnp.concatenate([ + collectibles.y, + collectibles.x, + collectibles.road.astype(jnp.float32), + collectibles.type_id.astype(jnp.float32), + collectibles.active.astype(jnp.float32), + ]) @partial(jax.jit, static_argnums=(0,)) def obs_to_flat_array(self, obs: UpNDownObservation) -> jnp.ndarray: + """Flatten the complete observation to a 1D array for RL. + + Order: + - Player car: 8 values (x, y, w, h, speed, type, road, direction_x) + - Enemy cars: MAX_ENEMY_CARS * 8 values (x, y, w, h, speed, type, road, active per car) + - Flags: NUM_FLAGS * 5 values (y, road, segment, color, collected per flag) + - Collectibles: MAX_COLLECTIBLES * 5 values (y, x, road, type, active per collectible) + - Flags collected mask: NUM_FLAGS values + - Score, lives, is_jumping, jump_cooldown, is_on_steep_road, round_started: 6 values + """ return jnp.concatenate([ - jnp.asarray(obs.player.x, dtype=jnp.int32).reshape(-1), - jnp.asarray(obs.player.y, dtype=jnp.int32).reshape(-1), - jnp.asarray(obs.player.height, dtype=jnp.int32).reshape(-1), - jnp.asarray(obs.player.width, dtype=jnp.int32).reshape(-1), + self.flatten_car(obs.player_car), + self.flatten_enemy_cars(obs.enemy_cars), + self.flatten_flags(obs.flags), + self.flatten_collectibles(obs.collectibles), + obs.flags_collected_mask.flatten().astype(jnp.float32), + jnp.array([obs.player_score], dtype=jnp.float32), + jnp.array([obs.lives], dtype=jnp.float32), + jnp.array([obs.is_jumping], dtype=jnp.float32), + jnp.array([obs.jump_cooldown], dtype=jnp.float32), + jnp.array([obs.is_on_steep_road], dtype=jnp.float32), + jnp.array([obs.round_started], dtype=jnp.float32), ]) def action_space(self) -> spaces.Discrete: return spaces.Discrete(6) - def observation_space(self) -> spaces: + def observation_space(self) -> spaces.Dict: + """Returns the observation space for Up N Down. + + The observation reuses existing game classes: + - player_car: Car with position (x, y, w, h), speed, type, current_road, direction_x + - enemy_cars: EnemyCars with positions, speeds, types, roads, active flags + - flags: Flag with y, road, road_segment, color_idx, collected + - collectibles: Collectible with y, x, road, type_id, active + - flags_collected_mask: boolean array of shape (NUM_FLAGS,) + - player_score: int (0-999999) + - lives: int (0-5) + - is_jumping: int (0 or 1) + - jump_cooldown: int (0-28) + - is_on_steep_road: int (0 or 1) + - round_started: int (0 or 1) + """ return spaces.Dict({ - "player": spaces.Dict({ - "x": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), - "y": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), - "width": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), - "height": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), + "player_car": spaces.Dict({ + "position": spaces.Dict({ + "x": spaces.Box(low=0, high=160, shape=(), dtype=jnp.float32), + "y": spaces.Box(low=-2000, high=0, shape=(), dtype=jnp.float32), + "width": spaces.Box(low=0, high=160, shape=(), dtype=jnp.float32), + "height": spaces.Box(low=0, high=210, shape=(), dtype=jnp.float32), + }), + "speed": spaces.Box(low=-6, high=6, shape=(), dtype=jnp.int32), + "type": spaces.Box(low=0, high=3, shape=(), dtype=jnp.int32), + "current_road": spaces.Box(low=0, high=2, shape=(), dtype=jnp.int32), + "road_index_A": spaces.Box(low=0, high=30, shape=(), dtype=jnp.int32), + "road_index_B": spaces.Box(low=0, high=30, shape=(), dtype=jnp.int32), + "direction_x": spaces.Box(low=-1, high=1, shape=(), dtype=jnp.int32), + }), + "enemy_cars": spaces.Dict({ + "position": spaces.Dict({ + "x": spaces.Box(low=0, high=160, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.float32), + "y": spaces.Box(low=-2000, high=0, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.float32), + "width": spaces.Box(low=0, high=160, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.float32), + "height": spaces.Box(low=0, high=210, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.float32), + }), + "speed": spaces.Box(low=-6, high=6, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), + "type": spaces.Box(low=0, high=3, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), + "current_road": spaces.Box(low=0, high=2, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), + "active": spaces.Box(low=0, high=1, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.bool_), + }), + "flags": spaces.Dict({ + "y": spaces.Box(low=-2000, high=0, shape=(self.consts.NUM_FLAGS,), dtype=jnp.float32), + "road": spaces.Box(low=0, high=1, shape=(self.consts.NUM_FLAGS,), dtype=jnp.int32), + "road_segment": spaces.Box(low=0, high=30, shape=(self.consts.NUM_FLAGS,), dtype=jnp.int32), + "color_idx": spaces.Box(low=0, high=7, shape=(self.consts.NUM_FLAGS,), dtype=jnp.int32), + "collected": spaces.Box(low=0, high=1, shape=(self.consts.NUM_FLAGS,), dtype=jnp.bool_), }), + "collectibles": spaces.Dict({ + "y": spaces.Box(low=-2000, high=0, shape=(self.consts.MAX_COLLECTIBLES,), dtype=jnp.float32), + "x": spaces.Box(low=0, high=160, shape=(self.consts.MAX_COLLECTIBLES,), dtype=jnp.float32), + "road": spaces.Box(low=0, high=1, shape=(self.consts.MAX_COLLECTIBLES,), dtype=jnp.int32), + "type_id": spaces.Box(low=0, high=3, shape=(self.consts.MAX_COLLECTIBLES,), dtype=jnp.int32), + "active": spaces.Box(low=0, high=1, shape=(self.consts.MAX_COLLECTIBLES,), dtype=jnp.bool_), + }), + "flags_collected_mask": spaces.Box(low=0, high=1, shape=(self.consts.NUM_FLAGS,), dtype=jnp.bool_), + "player_score": spaces.Box(low=0, high=999999, shape=(), dtype=jnp.int32), + "lives": spaces.Box(low=0, high=5, shape=(), dtype=jnp.int32), + "is_jumping": spaces.Box(low=0, high=1, shape=(), dtype=jnp.int32), + "jump_cooldown": spaces.Box(low=0, high=28, shape=(), dtype=jnp.int32), + "is_on_steep_road": spaces.Box(low=0, high=1, shape=(), dtype=jnp.int32), + "round_started": spaces.Box(low=0, high=1, shape=(), dtype=jnp.int32), }) def image_space(self) -> spaces.Box: @@ -1662,8 +1923,22 @@ def image_space(self) -> spaces.Box: ) @partial(jax.jit, static_argnums=(0,)) - def _get_info(self, state: UpNDownState, ) -> UpNDownInfo: - return UpNDownInfo(time=jnp.asarray(state.step_counter, dtype=jnp.int32)) + def _get_info(self, state: UpNDownState) -> UpNDownInfo: + """Build info dict with additional debugging/analysis data.""" + # Get current road segment for player + road_index = jnp.where( + state.player_car.current_road == 0, + state.player_car.road_index_A, + state.player_car.road_index_B, + ) + + return UpNDownInfo( + step_counter=jnp.int32(state.step_counter), + difficulty=jnp.int32(state.difficulty), + movement_steps=jnp.int32(state.movement_steps), + jump_slope=jnp.float32(state.jump_slope), + player_road_segment=jnp.int32(road_index), + ) @partial(jax.jit, static_argnums=(0,)) def _get_reward(self, previous_state: UpNDownState, state: UpNDownState): @@ -1707,11 +1982,10 @@ def _createBackgroundSprite(self, dimensions: Tuple[int, int]) -> jnp.ndarray: temp_pointer = self._createBackgroundSprite((1, 1)) blackout_square = self._createBackgroundSprite(self.consts.FLAG_BLACKOUT_SIZE) - # 2. Update asset config to include both walls + # Build asset config locally (matches other games' pattern) asset_config, road_files = self._get_asset_config(background, top_block, bottom_block, temp_pointer, blackout_square) sprite_path = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/" - # 3. Make a single call to the setup function ( self.PALETTE, self.SHAPE_MASKS, @@ -1800,6 +2074,37 @@ def _get_road_sprite_sizes(self, road_files: list[str]) -> list: sizes.append(sprite.shape[0]) complete_size = int(sum(sizes)) return sizes, complete_size + + def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.ndarray, bottomBlockSprite: jnp.ndarray, tempPointer: jnp.ndarray, blackoutSquare: jnp.ndarray) -> tuple[list, list[str]]: + """Return asset manifest and ordered road files (renderer-local like other games).""" + road_dir = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/roads" + road_files = sorted( + file for file in os.listdir(road_dir) + if file.endswith(".npy") + ) + roads = [f"roads/{file}" for file in road_files] + return [ + {'name': 'background', 'type': 'background', 'data': backgroundSprite}, + {'name': 'road', 'type': 'group', 'files': roads}, + {'name': 'player', 'type': 'single', 'file': 'player_car.npy'}, + {'name': 'camero_left', 'type': 'single', 'file': 'enemy_cars/camero_left.npy'}, + {'name': 'flag_carrier_left', 'type': 'single', 'file': 'enemy_cars/flag_carrier_left.npy'}, + {'name': 'pick_up_truck_left', 'type': 'single', 'file': 'enemy_cars/pick_up_truck_left.npy'}, + {'name': 'truck_left', 'type': 'single', 'file': 'enemy_cars/truck_left.npy'}, + {'name': 'wall_top', 'type': 'procedural', 'data': topBlockSprite}, + {'name': 'wall_bottom', 'type': 'procedural', 'data': bottomBlockSprite}, + {'name': 'all_flags_top', 'type': 'single', 'file': 'all_flags_top.npy'}, + {'name': 'all_lives_bottom', 'type': 'single', 'file': 'all_lives_bottom.npy'}, + {'name': 'score_digits', 'type': 'digits', 'pattern': 'score/score_{}.npy'}, + {'name': 'pink_flag', 'type': 'single', 'file': 'pink_flag.npy'}, + {'name': 'flag_pole', 'type': 'single', 'file': 'flag_pole.npy'}, + {'name': 'cherry', 'type': 'single', 'file': 'cherry.npy'}, + {'name': 'balloon', 'type': 'single', 'file': 'balloon.npy'}, + {'name': 'lollypop', 'type': 'single', 'file': 'lollypop.npy'}, + {'name': 'ice_cream', 'type': 'single', 'file': 'ice_cream_cone.npy'}, + {'name': 'tempPointer', 'type': 'procedural', 'data': tempPointer}, + {'name': 'blackout_square', 'type': 'procedural', 'data': blackoutSquare}, + ], roads def _get_x_on_road(self, y: chex.Array, road_segment: chex.Array, track_corners_x: chex.Array) -> chex.Array: """Calculate the X position on a road given a Y coordinate and road segment.""" @@ -1832,38 +2137,6 @@ def _jump_arc_offset(self, jump_cooldown: chex.Array) -> chex.Array: centered = (progress - 0.5) * 2.0 return self.consts.JUMP_ARC_HEIGHT * (1.0 - centered * centered) - def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.ndarray, bottomBlockSprite: jnp.ndarray, tempPointer: jnp.ndarray, blackoutSquare: jnp.ndarray) -> tuple[list, list[str]]: - """Returns the asset manifest and ordered road files.""" - road_dir = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/roads" - road_files = sorted( - file for file in os.listdir(road_dir) - if file.endswith(".npy") - ) - roads = [f"roads/{file}" for file in road_files] - return [ - {'name': 'background', 'type': 'background', 'data': backgroundSprite}, - {'name': 'road', 'type': 'group', 'files': roads}, - {'name': 'player', 'type': 'single', 'file': 'player_car.npy'}, - # Only load left-facing enemy sprites; right-facing are created by flipping - {'name': 'camero_left', 'type': 'single', 'file': 'enemy_cars/camero_left.npy'}, - {'name': 'flag_carrier_left', 'type': 'single', 'file': 'enemy_cars/flag_carrier_left.npy'}, - {'name': 'pick_up_truck_left', 'type': 'single', 'file': 'enemy_cars/pick_up_truck_left.npy'}, - {'name': 'truck_left', 'type': 'single', 'file': 'enemy_cars/truck_left.npy'}, - {'name': 'wall_top', 'type': 'procedural', 'data': topBlockSprite}, - {'name': 'wall_bottom', 'type': 'procedural', 'data': bottomBlockSprite}, - {'name': 'all_flags_top', 'type': 'single', 'file': 'all_flags_top.npy'}, - {'name': 'all_lives_bottom', 'type': 'single', 'file': 'all_lives_bottom.npy'}, - {'name': 'score_digits', 'type': 'digits', 'pattern': 'score/score_{}.npy'}, - {'name': 'pink_flag', 'type': 'single', 'file': 'pink_flag.npy'}, - {'name': 'flag_pole', 'type': 'single', 'file': 'flag_pole.npy'}, - {'name': 'cherry', 'type': 'single', 'file': 'cherry.npy'}, - {'name': 'balloon', 'type': 'single', 'file': 'balloon.npy'}, - {'name': 'lollypop', 'type': 'single', 'file': 'lollypop.npy'}, - {'name': 'ice_cream', 'type': 'single', 'file': 'ice_cream_cone.npy'}, - {'name': 'tempPointer', 'type': 'procedural', 'data': tempPointer}, - {'name': 'blackout_square', 'type': 'procedural', 'data': blackoutSquare}, - ], roads - @partial(jax.jit, static_argnums=(0,)) def render(self, state): raster = self.jr.create_object_raster(self.BACKGROUND) @@ -1928,7 +2201,12 @@ def render_enemy(carry, enemy_idx): enemy_type = state.enemy_cars.type[enemy_idx] direction_x = state.enemy_cars.direction_x[enemy_idx] screen_y = 105 + (enemy_y - state.player_car.position.y) - is_visible = jnp.logical_and(enemy_active, jnp.logical_and(screen_y > 25, screen_y < 195)) + # Hide enemies when awaiting round start or awaiting respawn + should_hide = jnp.logical_or(state.awaiting_round_start, state.awaiting_respawn) + is_visible = jnp.logical_and( + jnp.logical_and(enemy_active, jnp.logical_and(screen_y > 25, screen_y < 195)), + ~should_hide + ) enemy_mask = select_enemy_mask(enemy_type, direction_x < 0) raster = jax.lax.cond( @@ -1950,7 +2228,14 @@ def render_enemy(carry, enemy_idx): player_screen_y = jnp.int32(105 - jump_offset) player_mask = self.SHAPE_MASKS["player"] - raster_player = self.jr.render_at_clipped(raster_enemies, state.player_car.position.x, player_screen_y, player_mask) + # Skip rendering player when awaiting respawn OR awaiting round start + should_hide_player = jnp.logical_or(state.awaiting_respawn, state.awaiting_round_start) + raster_player = jax.lax.cond( + should_hide_player, + lambda _: raster_enemies, # Don't render player + lambda _: self.jr.render_at_clipped(raster_enemies, state.player_car.position.x, player_screen_y, player_mask), + operand=None, + ) wall_top_mask = self.SHAPE_MASKS["wall_top"] raster_wall_top = self.jr.render_at(raster_player, 0, 0, wall_top_mask) @@ -2002,9 +2287,11 @@ def render_flag(carry, flag_idx): operand=None, ) screen_y = 105 + (flag_y - state.player_car.position.y) + # Hide flags when awaiting round start or awaiting respawn + should_hide = jnp.logical_or(state.awaiting_round_start, state.awaiting_respawn) is_visible = jnp.logical_and( jnp.logical_and(screen_y > 25, screen_y < 195), - ~flag_collected + jnp.logical_and(~flag_collected, ~should_hide) ) color_id = self.flag_palette_ids[flag_color_idx] colored_flag_mask = jnp.where( @@ -2050,9 +2337,11 @@ def render_collectible(carry, collectible_idx): collectible_color_idx = state.collectibles.color_idx[collectible_idx] collectible_type_id = state.collectibles.type_id[collectible_idx] screen_y = 105 + (collectible_y - state.player_car.position.y) + # Hide collectibles when awaiting round start or awaiting respawn + should_hide = jnp.logical_or(state.awaiting_round_start, state.awaiting_respawn) is_visible = jnp.logical_and( jnp.logical_and(screen_y > 25, screen_y < 195), - collectible_active + jnp.logical_and(collectible_active, ~should_hide) ) def get_sprite_and_mask(type_id): From 867c805d94327a04d0c4b7c6dc203600669b7265 Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sun, 21 Dec 2025 19:37:42 +0100 Subject: [PATCH 31/76] fix observation logic for tests --- src/jaxatari/games/jax_upndown.py | 168 +++++++++++++++--------------- 1 file changed, 85 insertions(+), 83 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index ca04cba7c..7bf6abcab 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -188,28 +188,18 @@ class UpNDownState(NamedTuple): - class UpNDownObservation(NamedTuple): - """Complete observation for RL agents in Up N Down. - - Reuses existing game classes for consistency: - - player_car: Car with EntityPosition, speed, type, road info - - enemy_cars: EnemyCars pool with positions, speeds, types, active flags - - flags: Flag with y, road, segment, color, collected status - - collectibles: Collectible with positions, types, active status - - Additional game state: score, lives, jumping status, etc. - """ - player_car: Car # Reuse existing Car class - enemy_cars: EnemyCars # Reuse existing EnemyCars class - flags: Flag # Reuse existing Flag class - collectibles: Collectible # Reuse existing Collectible class - flags_collected_mask: jnp.ndarray # Shape (NUM_FLAGS,) - boolean mask - player_score: jnp.ndarray - lives: jnp.ndarray - is_jumping: jnp.ndarray # Whether player is currently jumping - jump_cooldown: jnp.ndarray # Frames remaining in jump - is_on_steep_road: jnp.ndarray # Whether currently on steep section - round_started: jnp.ndarray # Whether player has started moving + player_car: Car + enemy_cars: EnemyCars + flags: Flag + collectibles: Collectible + flags_collected_mask: chex.Array # Shape (NUM_FLAGS,) - int32 (0 or 1) + player_score: chex.Array + lives: chex.Array + is_jumping: chex.Array + jump_cooldown: chex.Array + is_on_steep_road: chex.Array + round_started: chex.Array class UpNDownInfo(NamedTuple): @@ -236,18 +226,18 @@ def __init__(self, consts: UpNDownConstants = None, reward_funcs: list[callable] Action.DOWNFIRE, ] # Calculate obs_size based on observation structure: - # Player car: 8 values (x, y, w, h, speed, type, road, direction_x) - # Enemy cars: MAX_ENEMY_CARS * 8 = 8 * 8 = 64 (x, y, w, h, speed, type, road, active per car) + # Player car: 10 values (x, y, w, h, speed, type, road, road_index_A, road_index_B, direction_x) + # Enemy cars: MAX_ENEMY_CARS * 12 = 8 * 12 = 96 (x, y, w, h, speed, type, road, road_index_A, road_index_B, direction_x, active, age) # Flags: NUM_FLAGS * 5 = 8 * 5 = 40 (y, road, segment, color, collected per flag) - # Collectibles: MAX_COLLECTIBLES * 5 = 1 * 5 = 5 (y, x, road, type, active per collectible) + # Collectibles: MAX_COLLECTIBLES * 6 = 1 * 6 = 6 (y, x, road, color_idx, type, active per collectible) # Flags collected mask: NUM_FLAGS = 8 # Score, lives, is_jumping, jump_cooldown, is_on_steep_road, round_started: 6 - # Total: 8 + 64 + 40 + 5 + 8 + 6 = 131 + # Total: 10 + 96 + 40 + 6 + 8 + 6 = 166 self.obs_size = ( - 8 + # player car - self.consts.MAX_ENEMY_CARS * 8 + # enemy cars + 10 + # player car + self.consts.MAX_ENEMY_CARS * 12 + # enemy cars (all fields) self.consts.NUM_FLAGS * 5 + # flags - self.consts.MAX_COLLECTIBLES * 5 + # collectibles + self.consts.MAX_COLLECTIBLES * 6 + # collectibles (all fields) self.consts.NUM_FLAGS + # flags_collected_mask 6 # score, lives, is_jumping, jump_cooldown, is_on_steep_road, round_started ) @@ -1746,7 +1736,7 @@ def render(self, state: UpNDownState) -> jnp.ndarray: def _get_observation(self, state: UpNDownState) -> UpNDownObservation: """Build complete observation for RL agents. - Reuses existing game classes directly from state for consistency. + Reuses existing game classes directly. Extra fields are filtered during flatten. """ # Check if on steep road is_on_steep_road = self._is_steep_road_segment( @@ -1760,7 +1750,7 @@ def _get_observation(self, state: UpNDownState) -> UpNDownObservation: enemy_cars=state.enemy_cars, flags=state.flags, collectibles=state.collectibles, - flags_collected_mask=state.flags_collected_mask, + flags_collected_mask=state.flags_collected_mask.astype(jnp.int32), player_score=jnp.int32(state.score), lives=jnp.int32(state.lives), is_jumping=jnp.int32(state.is_jumping), @@ -1773,50 +1763,57 @@ def _get_observation(self, state: UpNDownState) -> UpNDownObservation: def flatten_car(self, car: Car) -> jnp.ndarray: """Flatten a Car to a 1D array.""" return jnp.concatenate([ - jnp.array([car.position.x], dtype=jnp.float32), - jnp.array([car.position.y], dtype=jnp.float32), - jnp.array([car.position.width], dtype=jnp.float32), - jnp.array([car.position.height], dtype=jnp.float32), - jnp.array([car.speed], dtype=jnp.float32), - jnp.array([car.type], dtype=jnp.float32), - jnp.array([car.current_road], dtype=jnp.float32), - jnp.array([car.direction_x], dtype=jnp.float32), + jnp.array([car.position.x], dtype=jnp.int32), + jnp.array([car.position.y], dtype=jnp.int32), + jnp.array([car.position.width], dtype=jnp.int32), + jnp.array([car.position.height], dtype=jnp.int32), + jnp.array([car.speed], dtype=jnp.int32), + jnp.array([car.type], dtype=jnp.int32), + jnp.array([car.current_road], dtype=jnp.int32), + jnp.array([car.road_index_A], dtype=jnp.int32), + jnp.array([car.road_index_B], dtype=jnp.int32), + jnp.array([car.direction_x], dtype=jnp.int32), ]) @partial(jax.jit, static_argnums=(0,)) def flatten_enemy_cars(self, enemy_cars: EnemyCars) -> jnp.ndarray: - """Flatten EnemyCars to a 1D array.""" + """Flatten EnemyCars to a 1D array (all fields).""" return jnp.concatenate([ - enemy_cars.position.x, - enemy_cars.position.y, - enemy_cars.position.width, - enemy_cars.position.height, - enemy_cars.speed.astype(jnp.float32), - enemy_cars.type.astype(jnp.float32), - enemy_cars.current_road.astype(jnp.float32), - enemy_cars.active.astype(jnp.float32), + enemy_cars.position.x.astype(jnp.int32), + enemy_cars.position.y.astype(jnp.int32), + enemy_cars.position.width.astype(jnp.int32), + enemy_cars.position.height.astype(jnp.int32), + enemy_cars.speed.astype(jnp.int32), + enemy_cars.type.astype(jnp.int32), + enemy_cars.current_road.astype(jnp.int32), + enemy_cars.road_index_A.astype(jnp.int32), + enemy_cars.road_index_B.astype(jnp.int32), + enemy_cars.direction_x.astype(jnp.int32), + enemy_cars.active.astype(jnp.int32), + enemy_cars.age.astype(jnp.int32), ]) @partial(jax.jit, static_argnums=(0,)) def flatten_flags(self, flags: Flag) -> jnp.ndarray: """Flatten Flag to a 1D array.""" return jnp.concatenate([ - flags.y, - flags.road.astype(jnp.float32), - flags.road_segment.astype(jnp.float32), - flags.color_idx.astype(jnp.float32), - flags.collected.astype(jnp.float32), + flags.y.astype(jnp.int32), + flags.road.astype(jnp.int32), + flags.road_segment.astype(jnp.int32), + flags.color_idx.astype(jnp.int32), + flags.collected.astype(jnp.int32), ]) @partial(jax.jit, static_argnums=(0,)) def flatten_collectibles(self, collectibles: Collectible) -> jnp.ndarray: - """Flatten Collectible to a 1D array.""" + """Flatten Collectible to a 1D array (all fields).""" return jnp.concatenate([ - collectibles.y, - collectibles.x, - collectibles.road.astype(jnp.float32), - collectibles.type_id.astype(jnp.float32), - collectibles.active.astype(jnp.float32), + collectibles.y.astype(jnp.int32), + collectibles.x.astype(jnp.int32), + collectibles.road.astype(jnp.int32), + collectibles.color_idx.astype(jnp.int32), + collectibles.type_id.astype(jnp.int32), + collectibles.active.astype(jnp.int32), ]) @partial(jax.jit, static_argnums=(0,)) @@ -1824,10 +1821,10 @@ def obs_to_flat_array(self, obs: UpNDownObservation) -> jnp.ndarray: """Flatten the complete observation to a 1D array for RL. Order: - - Player car: 8 values (x, y, w, h, speed, type, road, direction_x) - - Enemy cars: MAX_ENEMY_CARS * 8 values (x, y, w, h, speed, type, road, active per car) + - Player car: 10 values (x, y, w, h, speed, type, road, road_index_A, road_index_B, direction_x) + - Enemy cars: MAX_ENEMY_CARS * 12 values (x, y, w, h, speed, type, road, road_index_A, road_index_B, direction_x, active, age) - Flags: NUM_FLAGS * 5 values (y, road, segment, color, collected per flag) - - Collectibles: MAX_COLLECTIBLES * 5 values (y, x, road, type, active per collectible) + - Collectibles: MAX_COLLECTIBLES * 6 values (y, x, road, color_idx, type, active per collectible) - Flags collected mask: NUM_FLAGS values - Score, lives, is_jumping, jump_cooldown, is_on_steep_road, round_started: 6 values """ @@ -1836,13 +1833,13 @@ def obs_to_flat_array(self, obs: UpNDownObservation) -> jnp.ndarray: self.flatten_enemy_cars(obs.enemy_cars), self.flatten_flags(obs.flags), self.flatten_collectibles(obs.collectibles), - obs.flags_collected_mask.flatten().astype(jnp.float32), - jnp.array([obs.player_score], dtype=jnp.float32), - jnp.array([obs.lives], dtype=jnp.float32), - jnp.array([obs.is_jumping], dtype=jnp.float32), - jnp.array([obs.jump_cooldown], dtype=jnp.float32), - jnp.array([obs.is_on_steep_road], dtype=jnp.float32), - jnp.array([obs.round_started], dtype=jnp.float32), + obs.flags_collected_mask.flatten().astype(jnp.int32), + jnp.array([obs.player_score], dtype=jnp.int32), + jnp.array([obs.lives], dtype=jnp.int32), + jnp.array([obs.is_jumping], dtype=jnp.int32), + jnp.array([obs.jump_cooldown], dtype=jnp.int32), + jnp.array([obs.is_on_steep_road], dtype=jnp.int32), + jnp.array([obs.round_started], dtype=jnp.int32), ]) def action_space(self) -> spaces.Discrete: @@ -1867,10 +1864,10 @@ def observation_space(self) -> spaces.Dict: return spaces.Dict({ "player_car": spaces.Dict({ "position": spaces.Dict({ - "x": spaces.Box(low=0, high=160, shape=(), dtype=jnp.float32), - "y": spaces.Box(low=-2000, high=0, shape=(), dtype=jnp.float32), - "width": spaces.Box(low=0, high=160, shape=(), dtype=jnp.float32), - "height": spaces.Box(low=0, high=210, shape=(), dtype=jnp.float32), + "x": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), + "y": spaces.Box(low=-2000, high=0, shape=(), dtype=jnp.int32), + "width": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), + "height": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), }), "speed": spaces.Box(low=-6, high=6, shape=(), dtype=jnp.int32), "type": spaces.Box(low=0, high=3, shape=(), dtype=jnp.int32), @@ -1881,31 +1878,36 @@ def observation_space(self) -> spaces.Dict: }), "enemy_cars": spaces.Dict({ "position": spaces.Dict({ - "x": spaces.Box(low=0, high=160, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.float32), - "y": spaces.Box(low=-2000, high=0, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.float32), - "width": spaces.Box(low=0, high=160, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.float32), - "height": spaces.Box(low=0, high=210, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.float32), + "x": spaces.Box(low=0, high=160, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), + "y": spaces.Box(low=-2000, high=0, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), + "width": spaces.Box(low=0, high=160, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), + "height": spaces.Box(low=0, high=210, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), }), "speed": spaces.Box(low=-6, high=6, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), "type": spaces.Box(low=0, high=3, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), "current_road": spaces.Box(low=0, high=2, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), - "active": spaces.Box(low=0, high=1, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.bool_), + "road_index_A": spaces.Box(low=0, high=30, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), + "road_index_B": spaces.Box(low=0, high=30, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), + "direction_x": spaces.Box(low=-1, high=1, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), + "active": spaces.Box(low=0, high=1, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), + "age": spaces.Box(low=0, high=10000, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), }), "flags": spaces.Dict({ - "y": spaces.Box(low=-2000, high=0, shape=(self.consts.NUM_FLAGS,), dtype=jnp.float32), + "y": spaces.Box(low=-2000, high=0, shape=(self.consts.NUM_FLAGS,), dtype=jnp.int32), "road": spaces.Box(low=0, high=1, shape=(self.consts.NUM_FLAGS,), dtype=jnp.int32), "road_segment": spaces.Box(low=0, high=30, shape=(self.consts.NUM_FLAGS,), dtype=jnp.int32), "color_idx": spaces.Box(low=0, high=7, shape=(self.consts.NUM_FLAGS,), dtype=jnp.int32), - "collected": spaces.Box(low=0, high=1, shape=(self.consts.NUM_FLAGS,), dtype=jnp.bool_), + "collected": spaces.Box(low=0, high=1, shape=(self.consts.NUM_FLAGS,), dtype=jnp.int32), }), "collectibles": spaces.Dict({ - "y": spaces.Box(low=-2000, high=0, shape=(self.consts.MAX_COLLECTIBLES,), dtype=jnp.float32), - "x": spaces.Box(low=0, high=160, shape=(self.consts.MAX_COLLECTIBLES,), dtype=jnp.float32), + "y": spaces.Box(low=-2000, high=0, shape=(self.consts.MAX_COLLECTIBLES,), dtype=jnp.int32), + "x": spaces.Box(low=0, high=160, shape=(self.consts.MAX_COLLECTIBLES,), dtype=jnp.int32), "road": spaces.Box(low=0, high=1, shape=(self.consts.MAX_COLLECTIBLES,), dtype=jnp.int32), + "color_idx": spaces.Box(low=0, high=7, shape=(self.consts.MAX_COLLECTIBLES,), dtype=jnp.int32), "type_id": spaces.Box(low=0, high=3, shape=(self.consts.MAX_COLLECTIBLES,), dtype=jnp.int32), - "active": spaces.Box(low=0, high=1, shape=(self.consts.MAX_COLLECTIBLES,), dtype=jnp.bool_), + "active": spaces.Box(low=0, high=1, shape=(self.consts.MAX_COLLECTIBLES,), dtype=jnp.int32), }), - "flags_collected_mask": spaces.Box(low=0, high=1, shape=(self.consts.NUM_FLAGS,), dtype=jnp.bool_), + "flags_collected_mask": spaces.Box(low=0, high=1, shape=(self.consts.NUM_FLAGS,), dtype=jnp.int32), "player_score": spaces.Box(low=0, high=999999, shape=(), dtype=jnp.int32), "lives": spaces.Box(low=0, high=5, shape=(), dtype=jnp.int32), "is_jumping": spaces.Box(low=0, high=1, shape=(), dtype=jnp.int32), From 4185e3137f3e5b7a7a0ac6378c057731bde6475a Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sat, 31 Jan 2026 17:14:14 +0100 Subject: [PATCH 32/76] try to implement some of the feedback --- src/jaxatari/games/jax_upndown.py | 291 ++++++++++++++++++++++-------- 1 file changed, 216 insertions(+), 75 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 7bf6abcab..5b903ab0b 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -15,9 +15,9 @@ class UpNDownConstants(NamedTuple): FRAME_SKIP: int = 4 DIFFICULTIES: chex.Array = jnp.array([0, 1, 2, 3, 4, 5]) - MAX_SPEED: int = 6 + MAX_SPEED: int = 7 INITIAL_LIVES: int = 5 - JUMP_ARC_HEIGHT: float = 18.0 + JUMP_ARC_HEIGHT: float = 22.0 RESPAWN_DELAY_FRAMES: int = 60 RESPAWN_Y: int = 0 RESPAWN_X: int = 30 @@ -58,6 +58,8 @@ class UpNDownConstants(NamedTuple): PASSIVE_SCORE_INTERVAL: int = 60 # Steps between passive score awards PASSIVE_SCORE_AMOUNT: int = 10 # Points awarded for passive scoring COLLISION_THRESHOLD: float = 5.0 # Distance threshold for flag/collectible collision + ACCELERATION_INTERVAL: int = 6 # Frames between speed changes when holding up/down + EXTRA_LIFE_THRESHOLD: int = 10000 # Score threshold for extra life TRACK_LENGTH: int = 1036 FIRST_TRACK_CORNERS_X: chex.Array = jnp.array([30, 75, 128, 75, 21, 75, 131, 111, 150, 95, 150, 115, 150, 108, 150, 115, 115, 75, 18, 38, 67, 38, 38, 20, 64, 30]) TRACK_CORNERS_Y: chex.Array = jnp.array([0, -40, -98, -155, -203, -268, -327, -347, -382, -467, -525, -565, -597, -625, -670, -705, -738, -788, -838, -862, -898, -925, -950, -972, -1000, -1035]) @@ -185,6 +187,9 @@ class UpNDownState(NamedTuple): awaiting_round_start: chex.Array # True at game start and after respawn until input received # Input debounce - requires button release before next input triggers round start input_released: chex.Array # True when player has released all buttons since last state change + jump_key_released: chex.Array # True if jump button was NOT pressed in previous step + last_extra_life_score: chex.Array # Score at which last extra life was awarded + jump_total_duration: chex.Array # Total duration of the current/last jump for rendering arc @@ -242,7 +247,7 @@ def __init__(self, consts: UpNDownConstants = None, reward_funcs: list[callable] 6 # score, lives, is_jumping, jump_cooldown, is_on_steep_road, round_started ) # Speed dividers for movement timing (indexed by speed level) - self._speed_dividers = jnp.array([0, 1, 2, 4, 8, 16, 16]) + self._speed_dividers = jnp.array([0, 1, 2, 4, 8, 16, 16, 16, 16]) @partial(jax.jit, static_argnums=(0,)) def _compute_movement_timing(self, speed: chex.Array, step_counter: chex.Array) -> Tuple[chex.Array, chex.Array, chex.Array, chex.Array]: @@ -261,7 +266,7 @@ def _compute_movement_timing(self, speed: chex.Array, step_counter: chex.Array) move_y = jnp.logical_and((step_counter % period) == (half_period % period), speed != 0) move_x = jnp.logical_and((step_counter % period) == 0, speed != 0) - step_size = jnp.where(speed_index >= self.consts.MAX_SPEED, 1.5, 1.0) + step_size = jnp.where(speed_index >= 6, 1.5 + (speed_index - 6) * 0.2, 1.0) return move_y, move_x, step_size, speed_sign def _apply_steep_road_penalty( @@ -547,6 +552,7 @@ def _advance_player_car( car_type: chex.Array, is_landing: chex.Array, stored_jump_slope: chex.Array, + jump_progress: chex.Array, ) -> Car: """ Advance the player car position. @@ -574,16 +580,59 @@ def _advance_player_car( position, slope, b, speed_sign, step_size, car_direction_x, move_y, move_x ) + # === JUMP PHYSICS NORMALIZATION === + # Normalize jump velocity so total speed (Euclidean) matches 'step_size' + # Without this, diagonal jumps cover more distance per frame than straight road movement + # stored_jump_slope is dX/dY + # Scaling factor = 1 / sqrt(1 + slope^2) + jump_speed_scaling = 1.0 / jnp.sqrt(1.0 + stored_jump_slope**2) + jump_step_size = step_size * jump_speed_scaling + # === Y MOVEMENT === - # When jumping: move freely in Y direction + # When jumping: move freely in Y direction but with normalized speed # When on road: use road-based movement result - jump_y = jnp.where(move_y, position_y + speed_sign * -step_size, position_y) + # Note: We must apply step_y on move_y ticks to keep sync with engine heartbeat + jump_y = jnp.where(move_y, position_y + speed_sign * -jump_step_size, position_y) new_player_y = jnp.where(is_jumping, jump_y, road_y) # === X MOVEMENT === # When jumping: use stored_jump_slope (locked at jump start) - moves X proportionally to Y - # When on road: use road-based movement result - jump_x = jnp.where(move_x, position_x - speed_sign * stored_jump_slope * step_size, position_x) + # Use jump_step_size to maintain correct trajectory and speed + # X step = slope * Y step magnitude = slope * jump_step_size + raw_jump_x = jnp.where(move_x, position_x - speed_sign * stored_jump_slope * jump_step_size, position_x) + + # === AIR STEERING / MAGNETISM === + # Gradually steer towards the nearest road while in the air to prevent "teleporting" on landing + segment_curr = self._get_road_segment(new_player_y) + road_A_x_curr = self._get_x_on_road(new_player_y, segment_curr, self.consts.FIRST_TRACK_CORNERS_X) + road_B_x_curr = self._get_x_on_road(new_player_y, segment_curr, self.consts.SECOND_TRACK_CORNERS_X) + + dist_A = jnp.abs(raw_jump_x - road_A_x_curr) + dist_B = jnp.abs(raw_jump_x - road_B_x_curr) + + # Find closest road center + target_road_x = jnp.where(dist_A < dist_B, road_A_x_curr, road_B_x_curr) + dist_to_target = target_road_x - raw_jump_x + + # Only nudge in the last 25% of the jump (progress > 0.75) + # when reasonably close to a road (within 2x tolerance) + # and only when player is between the two roads + + is_late_jump = jump_progress > 0.75 + is_reasonably_close = jnp.abs(dist_to_target) < (self.consts.LANDING_TOLERANCE * 2.0) + + # Check if player is between the two roads + min_road_x_curr = jnp.minimum(road_A_x_curr, road_B_x_curr) + max_road_x_curr = jnp.maximum(road_A_x_curr, road_B_x_curr) + is_between_roads = jnp.logical_and(raw_jump_x > min_road_x_curr, raw_jump_x < max_road_x_curr) + + should_magnet = jnp.logical_and(is_late_jump, jnp.logical_and(is_reasonably_close, is_between_roads)) + + # Nudge factor: reduced to 2% steering strength (very subtle) + nudge_amount = dist_to_target * 0.08 + + jump_x = raw_jump_x + jnp.where(should_magnet, nudge_amount, 0.0) + new_player_x = jnp.where(is_jumping, jump_x, road_x) # === LANDING LOGIC === @@ -617,12 +666,41 @@ def _advance_player_car( # Valid landing: on a road OR between roads (will snap to nearest) valid_landing = jnp.logical_or(on_any_road, between_roads) + # Bridge crossing physics: if speed is high, we can "skip" small water gaps (land on nearest road) + # In original game, bridges allow crossing without jumping if you have speed + can_bridge_gap = jnp.abs(speed) >= 5 + # If landing and between roads but not directly on a road, snap to nearest road should_snap = jnp.logical_and(is_landing, jnp.logical_and(between_roads, jnp.logical_not(on_any_road))) - final_player_x = jnp.where(should_snap, nearest_road_x, new_player_x) + # Also snap if we are "in water" but have speed to bridge the gap + should_snap_bridge = jnp.logical_and(is_landing, jnp.logical_and(can_bridge_gap, jnp.logical_not(valid_landing))) + + final_player_x = jnp.where(jnp.logical_or(should_snap, should_snap_bridge), nearest_road_x, new_player_x) - # Water landing (crash): landing outside the valid road area - landing_in_water = jnp.logical_and(is_landing, jnp.logical_not(valid_landing)) + # Water landing (crash): Only if NOT on road AND NOT between roads (i.e., landed completely outside) + # User clarification: "crashing should only be possible if you dont land in betweeen or on the roads" + + # Safe if: ON ROAD or BETWEEN ROADS + is_safe_landing = jnp.logical_or(on_any_road, between_roads) + + landing_in_water = jnp.logical_and( + is_landing, + jnp.logical_not(is_safe_landing) + ) + + # Snap logic: + # If landing BETWEEN roads but not ON a road -> snap to nearest (safe!) + # (Outside landings are now crashes, so no need to snap them) + should_snap = jnp.logical_and(is_landing, jnp.logical_and(between_roads, jnp.logical_not(on_any_road))) + + # Also snap if bridging (fast jump across water gap) + should_snap_bridge = jnp.logical_and(is_landing, jnp.logical_and(between_roads, can_bridge_gap)) + + final_player_x = jnp.where( + jnp.logical_or(should_snap, should_snap_bridge), + nearest_road_x, + new_player_x + ) # === UPDATE ROAD STATE === # Determine which road to assign on landing (priority: road A > road B > nearest) @@ -712,44 +790,23 @@ def _advance_car_core( @partial(jax.jit, static_argnums=(0,)) def _flag_step(self, state: UpNDownState, new_player_y: chex.Array, player_x: chex.Array, current_road: chex.Array) -> Tuple[Flag, chex.Array, chex.Array]: - """Update flag collection state and score. + """Update flag collection state and score (vectorized).""" + # Calculate flag X positions on both roads + # _get_x_on_road supports array inputs via advanced indexing + x_road_0 = self._get_x_on_road(state.flags.y, state.flags.road_segment, self.consts.FIRST_TRACK_CORNERS_X) + x_road_1 = self._get_x_on_road(state.flags.y, state.flags.road_segment, self.consts.SECOND_TRACK_CORNERS_X) - Args: - state: Current game state - new_player_y: Updated player Y position after movement - player_x: Current player X position - current_road: Current road player is on - - Returns: - Tuple of (updated_flags, score_delta, flags_collected_mask) - """ - # Check collision for each flag - def check_flag_collision(flag_idx): - flag_y = state.flags.y[flag_idx] - flag_road = state.flags.road[flag_idx] - flag_collected = state.flags.collected[flag_idx] - - # Calculate flag X position on its road - flag_segment = state.flags.road_segment[flag_idx] - flag_x = jax.lax.cond( - flag_road == 0, - lambda _: self._get_x_on_road(flag_y, flag_segment, self.consts.FIRST_TRACK_CORNERS_X), - lambda _: self._get_x_on_road(flag_y, flag_segment, self.consts.SECOND_TRACK_CORNERS_X), - operand=None, - ) - - # Check if player is close enough to collect the flag - y_distance = jnp.abs(new_player_y - flag_y) - x_distance = jnp.abs(player_x - flag_x) - same_road = (current_road == flag_road) - - collision = jnp.logical_and( - jnp.logical_and(y_distance < self.consts.COLLISION_THRESHOLD, x_distance < self.consts.COLLISION_THRESHOLD), - jnp.logical_and(same_road, ~flag_collected) - ) - return collision + flag_x = jnp.where(state.flags.road == 0, x_road_0, x_road_1) - new_collections = jax.vmap(check_flag_collision)(jnp.arange(self.consts.NUM_FLAGS)) + # Vectorized distance check + y_dist = jnp.abs(new_player_y - state.flags.y) + x_dist = jnp.abs(player_x - flag_x) + same_road = (current_road == state.flags.road) + + new_collections = jnp.logical_and( + jnp.logical_and(y_dist < self.consts.COLLISION_THRESHOLD, x_dist < self.consts.COLLISION_THRESHOLD), + jnp.logical_and(same_road, ~state.flags.collected) + ) # Update flags collected state new_flags_collected = jnp.logical_or(state.flags.collected, new_collections) @@ -929,7 +986,7 @@ def _death_step(self, state: UpNDownState) -> UpNDownState: def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: up = jnp.logical_or(action == Action.UP, action == Action.UPFIRE) down = jnp.logical_or(action == Action.DOWN, action == Action.DOWNFIRE) - jump = jnp.logical_or(action == Action.FIRE, jnp.logical_or(action == Action.UPFIRE, action == Action.DOWNFIRE)) + jump_pressed = jnp.logical_or(action == Action.FIRE, jnp.logical_or(action == Action.UPFIRE, action == Action.DOWNFIRE)) # Check if on a steep road section FIRST (before applying speed changes) is_on_steep_road = self._is_steep_road_segment( @@ -952,19 +1009,34 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: # Start with current speed player_speed = state.player_car.speed - # === STEEP ROAD BLOCKING LOGIC === + # === FRICTION & MOMENTUM LOGIC === + is_accelerating = up + is_braking = down + + # No friction - speed stays constant when no input + # Speed changes gradually (periodically, not every frame) + should_change_speed = (state.step_counter % self.consts.ACCELERATION_INTERVAL) == 0 + + # === ACCELERATION (UP) === # On steep road: UP action has NO effect (can't accelerate while on steep section) - # Apply UP acceleration only if NOT on steep road (or if jumping over it) can_accelerate = jnp.logical_not(on_steep_not_jumping) + player_speed = jnp.where( - jnp.logical_and(jnp.logical_and(player_speed < self.consts.MAX_SPEED, up), can_accelerate), + jnp.logical_and( + jnp.logical_and(should_change_speed, is_accelerating), + jnp.logical_and(player_speed < self.consts.MAX_SPEED, can_accelerate) + ), player_speed + 1, player_speed, ) + # === BRAKING (DOWN) === # DOWN action always works (can brake/reverse) player_speed = jnp.where( - jnp.logical_and(player_speed > -self.consts.MAX_SPEED, down), + jnp.logical_and( + jnp.logical_and(should_change_speed, is_braking), + player_speed > -self.consts.MAX_SPEED + ), player_speed - 1, player_speed, ) @@ -983,9 +1055,13 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: # Check if player has reached halfway point (50% progress through segment) past_halfway = steep_progress >= 0.5 + # Check if player has enough momentum to climb steep road + MIN_CLIMB_SPEED = 5 + has_momentum = player_speed >= MIN_CLIMB_SPEED + # Two behaviors based on progress: # 1. Before halfway: gradually reduce speed using timer - # 2. At/past halfway: immediately set speed to -2 (slide back) + # 2. At/past halfway: immediately slide back UNLESS we have enough momentum # Before halfway: reduce speed periodically using timer should_reduce_speed = jnp.logical_and( @@ -1007,18 +1083,25 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: steep_road_timer, ) - # At/past halfway: force speed to -2 (slide back down) - should_slide_back = jnp.logical_and(on_steep_going_up, past_halfway) + # At/past halfway: force speed to -2 (slide back down) IF momentum is lost + should_slide_back = jnp.logical_and( + on_steep_going_up, + jnp.logical_and(past_halfway, jnp.logical_not(has_momentum)) + ) player_speed = jnp.where( should_slide_back, jnp.int32(-3), player_speed, ) - can_start_jump = jnp.logical_and(state.jump_cooldown == 0, state.post_jump_cooldown == 0) + # === JUMP LOGIC === + can_start_jump = jnp.logical_and( + state.jump_cooldown == 0, + jnp.logical_and(state.post_jump_cooldown == 0, state.jump_key_released) + ) is_jumping = jnp.logical_or( jnp.logical_and(state.is_jumping, state.jump_cooldown > 0), - jnp.logical_and(state.is_on_road,jnp.logical_and(can_start_jump, jump)), + jnp.logical_and(state.is_on_road, jnp.logical_and(player_speed >= 0, jnp.logical_and(can_start_jump, jump_pressed))), ) # Detect when a new jump is starting (was not jumping, now is jumping) @@ -1067,11 +1150,18 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: # Lock slope at jump start, keep previous slope during jump (use jnp.where) jump_slope = jnp.where(starting_jump, new_jump_slope, state.jump_slope) + # Calculate dynamic jump duration based on speed + # Faster speed = shorter jump duration (covering gap faster) + # Increased base duration for more "air time" as requested + # Formula: 48 - 2 * abs(speed) -> Speed 8 = 32 frames (was 24 before) + current_jump_duration = 48 - 2 * jnp.abs(player_speed) + jump_duration = jnp.where(starting_jump, current_jump_duration.astype(jnp.int32), state.jump_total_duration) + # Use jnp.where for branchless execution of jump_cooldown jump_cooldown = jnp.where( state.jump_cooldown > 0, state.jump_cooldown - 1, - jnp.where(is_jumping, self.consts.JUMP_FRAMES, 0), + jnp.where(is_jumping, jump_duration, 0), ) # Use jnp.where for branchless execution of post_jump_cooldown @@ -1084,6 +1174,13 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: is_on_road = ~is_jumping is_landing = is_landing_now + # Calculate jump progress for magnetism + # Progress = (Total - Remaining) / Total + # Use jnp.maximum(..., 1.0) to avoid division by zero + safe_total_duration = jnp.maximum(state.jump_total_duration, 1.0) + jump_progress = (safe_total_duration - jump_cooldown.astype(jnp.float32)) / safe_total_duration + jump_progress = jnp.clip(jump_progress, 0.0, 1.0) + updated_player_car = self._advance_player_car( position_x=state.player_car.position.x, position_y=state.player_car.position.y, @@ -1098,12 +1195,16 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: car_type=state.player_car.type, is_landing=is_landing, stored_jump_slope=jump_slope, + jump_progress=jump_progress, ) # Check if a speed-changing action (UP or DOWN) was taken speed_action_taken = jnp.logical_or(up, down) # Round starts only after a speed-changing action round_started_now = jnp.logical_or(state.round_started, speed_action_taken) + + # Track jump key release for preventing held-key jumps + next_jump_key_released = jnp.logical_not(jump_pressed) next_state = state._replace( jump_cooldown=jump_cooldown, @@ -1116,6 +1217,8 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: movement_steps=jnp.where(round_started_now, state.movement_steps + 1, state.movement_steps), steep_road_timer=steep_road_timer, jump_slope=jump_slope, + jump_key_released=next_jump_key_released, + jump_total_duration=jump_duration, ) water_crash = jnp.logical_and(is_landing, updated_player_car.current_road == 2) @@ -1158,12 +1261,34 @@ def _flag_step_main(self, state: UpNDownState) -> UpNDownState: ) @partial(jax.jit, static_argnums=(0,)) - def _completion_bonus_step(self, state: UpNDownState) -> UpNDownState: - """Award bonus when all flags are collected.""" + def _level_progression_step(self, state: UpNDownState) -> UpNDownState: + """Handle level completion: award bonus and reset flags.""" all_flags_collected = jnp.all(state.flags_collected_mask) - # Use jnp.where for branchless execution + bonus = jnp.where(all_flags_collected, self.consts.ALL_FLAGS_BONUS, 0) - return state._replace(score=state.score + bonus) + + # Reset flags if all collected + new_collected = jnp.where(all_flags_collected, jnp.zeros_like(state.flags.collected), state.flags.collected) + new_mask = jnp.where(all_flags_collected, jnp.zeros_like(state.flags_collected_mask), state.flags_collected_mask) + + updated_flags = state.flags._replace(collected=new_collected) + + return state._replace( + score=state.score + bonus, + flags=updated_flags, + flags_collected_mask=new_mask + ) + + @partial(jax.jit, static_argnums=(0,)) + def _extra_life_step(self, state: UpNDownState) -> UpNDownState: + """Award extra life every 10000 points.""" + next_milestone = state.last_extra_life_score + self.consts.EXTRA_LIFE_THRESHOLD + should_award = state.score >= next_milestone + + new_lives = jnp.where(should_award, state.lives + 1, state.lives) + new_last_score = jnp.where(should_award, next_milestone, state.last_extra_life_score) + + return state._replace(lives=new_lives, last_extra_life_score=new_last_score) @partial(jax.jit, static_argnums=(0,)) def _collectible_step_main(self, state: UpNDownState) -> UpNDownState: @@ -1473,6 +1598,9 @@ def _respawn_after_collision(self, state: UpNDownState, new_lives: chex.Array) - awaiting_respawn=jnp.array(False), awaiting_round_start=jnp.array(True), # Wait for input to start round after respawn input_released=jnp.array(False), # Require button release before round can start + jump_key_released=jnp.array(True), + last_extra_life_score=state.last_extra_life_score, + jump_total_duration=jnp.array(self.consts.JUMP_FRAMES, dtype=jnp.int32), rng_key=rng_key, ) @@ -1496,9 +1624,11 @@ def _enemy_collision_step_main(self, state: UpNDownState) -> UpNDownState: # For ground collision: only trigger when enemy position is within tight distance overlap_x_ground = dx <= self.consts.GROUND_COLLISION_DISTANCE overlap_y_ground = wrapped_dy <= self.consts.GROUND_COLLISION_DISTANCE - # For late jump collision: use larger overlap based on car dimensions - overlap_x_jump = dx <= (state.player_car.position.width + state.enemy_cars.position.width) / 2.0 - overlap_y_jump = wrapped_dy <= (state.player_car.position.height + state.enemy_cars.position.height) / 2.0 + # For late jump collision: use larger overlap based on car dimensions plus extra tolerance + # "slightly more forgiving" + jump_tolerance = 4.0 + overlap_x_jump = dx <= (state.player_car.position.width + state.enemy_cars.position.width) / 2.0 + jump_tolerance + overlap_y_jump = wrapped_dy <= (state.player_car.position.height + state.enemy_cars.position.height) / 2.0 + jump_tolerance same_road = state.enemy_cars.current_road == state.player_car.current_road # Ground collision mask uses tight 3-pixel distance and same road @@ -1650,6 +1780,9 @@ def _reset_jit(self, key: chex.PRNGKey) -> Tuple[UpNDownObservation, UpNDownStat awaiting_respawn=jnp.array(False), awaiting_round_start=jnp.array(True), # Start frozen until first input input_released=jnp.array(True), # Can start immediately at game start + jump_key_released=jnp.array(True), + last_extra_life_score=jnp.array(0, dtype=jnp.int32), + jump_total_duration=jnp.array(self.consts.JUMP_FRAMES, dtype=jnp.int32), ) initial_obs = self._get_observation(state) return initial_obs, state @@ -1702,7 +1835,8 @@ def run_game_logic(s): s = self._death_step(s) s = self._passive_score_step_main(s) s = self._flag_step_main(s) - s = self._completion_bonus_step(s) + s = self._level_progression_step(s) + s = self._extra_life_step(s) s = self._collectible_step_main(s) s = self._enemy_step_main(s) s = self._enemy_collision_step_main(s) @@ -2131,9 +2265,9 @@ def _compute_flag_palette_ids(self) -> jnp.ndarray: return jnp.array([self._find_palette_id(color) for color in self.consts.FLAG_COLORS], dtype=jnp.int32) @partial(jax.jit, static_argnums=(0,)) - def _jump_arc_offset(self, jump_cooldown: chex.Array) -> chex.Array: + def _jump_arc_offset(self, jump_cooldown: chex.Array, total_duration: chex.Array) -> chex.Array: """Return a simple parabolic jump height based on remaining jump frames.""" - total = jnp.array(self.consts.JUMP_FRAMES, dtype=jnp.float32) + total = total_duration.astype(jnp.float32) remaining = jnp.array(jump_cooldown, dtype=jnp.float32) progress = jnp.clip((total - remaining) / jnp.maximum(total, 1.0), 0.0, 1.0) centered = (progress - 0.5) * 2.0 @@ -2195,13 +2329,20 @@ def select_enemy_mask(enemy_type: chex.Array, going_left: chex.Array): right_mask = self.enemy_right_masks[enemy_type] return jnp.where(going_left, left_mask, right_mask) + # Pre-cast enemy properties to optimal types for rendering BEFORE the scan loop + enemy_active_arr = state.enemy_cars.active + enemy_x_arr = state.enemy_cars.position.x.astype(jnp.int32) + enemy_y_arr = state.enemy_cars.position.y + enemy_type_arr = state.enemy_cars.type + enemy_direction_x_arr = state.enemy_cars.direction_x + def render_enemy(carry, enemy_idx): raster = carry - enemy_active = state.enemy_cars.active[enemy_idx] - enemy_x = state.enemy_cars.position.x[enemy_idx] - enemy_y = state.enemy_cars.position.y[enemy_idx] - enemy_type = state.enemy_cars.type[enemy_idx] - direction_x = state.enemy_cars.direction_x[enemy_idx] + enemy_active = enemy_active_arr[enemy_idx] + enemy_x = enemy_x_arr[enemy_idx] + enemy_y = enemy_y_arr[enemy_idx] + enemy_type = enemy_type_arr[enemy_idx] + direction_x = enemy_direction_x_arr[enemy_idx] screen_y = 105 + (enemy_y - state.player_car.position.y) # Hide enemies when awaiting round start or awaiting respawn should_hide = jnp.logical_or(state.awaiting_round_start, state.awaiting_respawn) @@ -2213,7 +2354,7 @@ def render_enemy(carry, enemy_idx): raster = jax.lax.cond( is_visible, - lambda r: self.jr.render_at(r, enemy_x.astype(jnp.int32), screen_y.astype(jnp.int32), enemy_mask), + lambda r: self.jr.render_at(r, enemy_x, screen_y.astype(jnp.int32), enemy_mask), lambda r: r, operand=raster, ) @@ -2223,7 +2364,7 @@ def render_enemy(carry, enemy_idx): jump_offset = jax.lax.cond( state.is_jumping, - lambda _: self._jump_arc_offset(state.jump_cooldown), + lambda _: self._jump_arc_offset(state.jump_cooldown, state.jump_total_duration), lambda _: jnp.array(0.0, dtype=jnp.float32), operand=None, ) From 3718ff166b104dc07702522105cbaf0b3b3fed48 Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Thu, 5 Mar 2026 14:54:53 +0100 Subject: [PATCH 33/76] add initial mod implementation --- src/jaxatari/games/jax_upndown.py | 73 ++++++++++++++-- .../games/mods/upndown_mod_plugins.py | 83 +++++++++++++++++++ src/jaxatari/games/mods/upndown_mods.py | 40 +++++++++ 3 files changed, 187 insertions(+), 9 deletions(-) create mode 100644 src/jaxatari/games/mods/upndown_mod_plugins.py create mode 100644 src/jaxatari/games/mods/upndown_mods.py diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 5b903ab0b..773e2e1b2 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -333,6 +333,46 @@ def _apply_steep_road_penalty( return final_speed, final_timer, jump_boost + @partial(jax.jit, static_argnums=(0,)) + def _sample_enemy_spawn_road(self, rng_key: chex.PRNGKey) -> chex.Array: + """Sample road index for enemy spawns. + + Extracted as a modding hook; default behavior is unchanged. + """ + return jax.random.randint(rng_key, shape=(), minval=0, maxval=2).astype(jnp.int32) + + @partial(jax.jit, static_argnums=(0,)) + def _collectible_score_values(self, state: UpNDownState, collectible_type_ids: chex.Array) -> chex.Array: + """Return score values for collectible types. + + Extracted as a modding hook; default behavior is unchanged. + """ + return self.consts.COLLECTIBLE_SCORES[collectible_type_ids] + + @partial(jax.jit, static_argnums=(0,)) + def _on_level_completed(self, state: UpNDownState) -> UpNDownState: + """Optional callback invoked only when all flags are collected. + + Default is a no-op and preserves existing game behavior. + """ + return state + + @partial(jax.jit, static_argnums=(0,)) + def _jump_speed_allows_start(self, player_speed: chex.Array) -> chex.Array: + """Return whether jump start is allowed for the current speed. + + Extracted as a modding hook; default behavior is unchanged. + """ + return player_speed >= 0 + + @partial(jax.jit, static_argnums=(0,)) + def _adjust_enemy_spawn_timer(self, state: UpNDownState, spawn_timer: chex.Array) -> chex.Array: + """Optional hook to post-process enemy spawn timer. + + Extracted as a modding hook; default behavior is unchanged. + """ + return spawn_timer + @partial(jax.jit, static_argnums=(0,)) def _get_slope_and_intercept_from_indices(self, current_road: chex.Array, road_index_A: chex.Array, road_index_B: chex.Array) -> Tuple[chex.Array, chex.Array]: """Calculate slope and intercept for the current road segment.""" @@ -925,8 +965,8 @@ def check_collision(idx): # Deactivate collected items final_active = jnp.logical_and(active_after_despawn, ~collections) - # Update score - vectorized lookup without vmap overhead - scores = self.consts.COLLECTIBLE_SCORES[spawned_type_id] + # Update score - extracted into hook for easier modding + scores = self._collectible_score_values(state, spawned_type_id) score_delta = jnp.sum(jnp.where(collections, scores, 0)) # Create final collectibles state @@ -1101,7 +1141,13 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: ) is_jumping = jnp.logical_or( jnp.logical_and(state.is_jumping, state.jump_cooldown > 0), - jnp.logical_and(state.is_on_road, jnp.logical_and(player_speed >= 0, jnp.logical_and(can_start_jump, jump_pressed))), + jnp.logical_and( + state.is_on_road, + jnp.logical_and( + self._jump_speed_allows_start(player_speed), + jnp.logical_and(can_start_jump, jump_pressed), + ), + ), ) # Detect when a new jump is starting (was not jumping, now is jumping) @@ -1273,12 +1319,19 @@ def _level_progression_step(self, state: UpNDownState) -> UpNDownState: updated_flags = state.flags._replace(collected=new_collected) - return state._replace( + next_state = state._replace( score=state.score + bonus, flags=updated_flags, flags_collected_mask=new_mask ) + return jax.lax.cond( + all_flags_collected, + lambda s: self._on_level_completed(s), + lambda s: s, + next_state, + ) + @partial(jax.jit, static_argnums=(0,)) def _extra_life_step(self, state: UpNDownState) -> UpNDownState: """Award extra life every 10000 points.""" @@ -1438,7 +1491,7 @@ def _enemy_step_main(self, state: UpNDownState) -> UpNDownState: spawn_side = jax.random.choice(key_spawn_side, jnp.array([-1.0, 1.0])) raw_spawn_y = state.player_car.position.y + spawn_side * spawn_offset spawn_y = -(((raw_spawn_y) * -1) % self.consts.TRACK_LENGTH) - spawn_road = jax.random.randint(key_spawn_direction, shape=(), minval=0, maxval=2) + spawn_road = self._sample_enemy_spawn_road(key_spawn_direction) segment_spawn = self._get_road_segment(spawn_y) spawn_x = jnp.where( @@ -1533,6 +1586,8 @@ def _enemy_step_main(self, state: UpNDownState) -> UpNDownState: age=enemy_age, ) + spawn_timer = self._adjust_enemy_spawn_timer(state, spawn_timer) + return state._replace( enemy_cars=next_enemy_cars, enemy_spawn_timer=spawn_timer, @@ -1991,7 +2046,7 @@ def observation_space(self) -> spaces.Dict: - player_score: int (0-999999) - lives: int (0-5) - is_jumping: int (0 or 1) - - jump_cooldown: int (0-28) + - jump_cooldown: int (0-48) - is_on_steep_road: int (0 or 1) - round_started: int (0 or 1) """ @@ -2003,7 +2058,7 @@ def observation_space(self) -> spaces.Dict: "width": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), "height": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), }), - "speed": spaces.Box(low=-6, high=6, shape=(), dtype=jnp.int32), + "speed": spaces.Box(low=-self.consts.MAX_SPEED, high=self.consts.MAX_SPEED, shape=(), dtype=jnp.int32), "type": spaces.Box(low=0, high=3, shape=(), dtype=jnp.int32), "current_road": spaces.Box(low=0, high=2, shape=(), dtype=jnp.int32), "road_index_A": spaces.Box(low=0, high=30, shape=(), dtype=jnp.int32), @@ -2017,7 +2072,7 @@ def observation_space(self) -> spaces.Dict: "width": spaces.Box(low=0, high=160, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), "height": spaces.Box(low=0, high=210, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), }), - "speed": spaces.Box(low=-6, high=6, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), + "speed": spaces.Box(low=-(self.consts.ENEMY_SPEED_MAX + 1), high=(self.consts.ENEMY_SPEED_MAX + 1), shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), "type": spaces.Box(low=0, high=3, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), "current_road": spaces.Box(low=0, high=2, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), "road_index_A": spaces.Box(low=0, high=30, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), @@ -2045,7 +2100,7 @@ def observation_space(self) -> spaces.Dict: "player_score": spaces.Box(low=0, high=999999, shape=(), dtype=jnp.int32), "lives": spaces.Box(low=0, high=5, shape=(), dtype=jnp.int32), "is_jumping": spaces.Box(low=0, high=1, shape=(), dtype=jnp.int32), - "jump_cooldown": spaces.Box(low=0, high=28, shape=(), dtype=jnp.int32), + "jump_cooldown": spaces.Box(low=0, high=48, shape=(), dtype=jnp.int32), "is_on_steep_road": spaces.Box(low=0, high=1, shape=(), dtype=jnp.int32), "round_started": spaces.Box(low=0, high=1, shape=(), dtype=jnp.int32), }) diff --git a/src/jaxatari/games/mods/upndown_mod_plugins.py b/src/jaxatari/games/mods/upndown_mod_plugins.py new file mode 100644 index 000000000..5382e0efa --- /dev/null +++ b/src/jaxatari/games/mods/upndown_mod_plugins.py @@ -0,0 +1,83 @@ +from functools import partial +import chex +import jax +import jax.numpy as jnp + +from jaxatari.games.jax_upndown import UpNDownState +from jaxatari.modification import JaxAtariInternalModPlugin + + +class RemoveStepRoadsMod(JaxAtariInternalModPlugin): + @partial(jax.jit, static_argnums=(0,)) + def _is_steep_road_segment(self, current_road: chex.Array, road_index_A: chex.Array, road_index_B: chex.Array) -> chex.Array: + return jnp.array(False) + + +class HigherPlayerSpeedMod(JaxAtariInternalModPlugin): + constants_overrides = { + "MAX_SPEED": 9, + } + + +class MoreCollectiblesMod(JaxAtariInternalModPlugin): + constants_overrides = { + "MAX_COLLECTIBLES": 4, + "COLLECTIBLE_SPAWN_INTERVAL": 120, + } + + +class MinCarSpawnGapMod(JaxAtariInternalModPlugin): + conflicts_with = ["progressive_car_spawn_rate"] + constants_overrides = { + "ENEMY_SPAWN_INTERVAL_BASE": 50, + } + + +class AllowJumpBackwardsMod(JaxAtariInternalModPlugin): + @partial(jax.jit, static_argnums=(0,)) + def _jump_speed_allows_start(self, player_speed: chex.Array) -> chex.Array: + return jnp.array(True) + + +class SingleLaneCarSpawnMod(JaxAtariInternalModPlugin): + @partial(jax.jit, static_argnums=(0,)) + def _sample_enemy_spawn_road(self, rng_key: chex.PRNGKey) -> chex.Array: + return jnp.array(1, dtype=jnp.int32) + + +class ProgressiveCarSpawnRateMod(JaxAtariInternalModPlugin): + conflicts_with = ["minimum_car_spawn_gap"] + + @partial(jax.jit, static_argnums=(0,)) + def _adjust_enemy_spawn_timer(self, state: UpNDownState, spawn_timer: chex.Array) -> chex.Array: + start_interval = jnp.int32(self._env.consts.ENEMY_SPAWN_INTERVAL_BASE) + min_interval = jnp.int32(8) + horizon = jnp.float32(1800.0) + + progress = jnp.clip(state.movement_steps.astype(jnp.float32) / horizon, 0.0, 1.0) + decayed_interval = jnp.round( + start_interval.astype(jnp.float32) - progress * (start_interval.astype(jnp.float32) - min_interval.astype(jnp.float32)) + ).astype(jnp.int32) + + target_interval = jnp.maximum(min_interval, decayed_interval) + return jnp.minimum(spawn_timer, target_interval) + + @partial(jax.jit, static_argnums=(0,)) + def _on_level_completed(self, state: UpNDownState) -> UpNDownState: + return state._replace( + movement_steps=jnp.array(0, dtype=jnp.int32), + enemy_spawn_timer=jnp.array(self._env.consts.ENEMY_SPAWN_INTERVAL_BASE, dtype=jnp.int32), + ) + + +class TimeDecayCollectibleValueMod(JaxAtariInternalModPlugin): + @partial(jax.jit, static_argnums=(0,)) + def _on_level_completed(self, state: UpNDownState) -> UpNDownState: + return state._replace(movement_steps=jnp.array(0, dtype=jnp.int32)) + + @partial(jax.jit, static_argnums=(0,)) + def _collectible_score_values(self, state: UpNDownState, collectible_type_ids: chex.Array) -> chex.Array: + base_scores = self._env.consts.COLLECTIBLE_SCORES[collectible_type_ids] + elapsed_decay = jnp.floor(state.movement_steps.astype(jnp.float32) / 200.0).astype(jnp.int32) + min_scores = jnp.maximum(jnp.int32(10), base_scores // 3) + return jnp.maximum(base_scores - elapsed_decay, min_scores) diff --git a/src/jaxatari/games/mods/upndown_mods.py b/src/jaxatari/games/mods/upndown_mods.py new file mode 100644 index 000000000..0dd2123b5 --- /dev/null +++ b/src/jaxatari/games/mods/upndown_mods.py @@ -0,0 +1,40 @@ +from jaxatari.modification import JaxAtariModController +from jaxatari.games.mods.upndown_mod_plugins import ( + AllowJumpBackwardsMod, + RemoveStepRoadsMod, + HigherPlayerSpeedMod, + MoreCollectiblesMod, + MinCarSpawnGapMod, + SingleLaneCarSpawnMod, + ProgressiveCarSpawnRateMod, + TimeDecayCollectibleValueMod, +) + + +UPNDOWN_MOD_REGISTRY = { + "allow_jump_backwards": AllowJumpBackwardsMod, + "remove_step_roads": RemoveStepRoadsMod, + "higher_player_speed": HigherPlayerSpeedMod, + "spawn_more_collectibles": MoreCollectiblesMod, + "minimum_car_spawn_gap": MinCarSpawnGapMod, + "single_lane_car_spawn": SingleLaneCarSpawnMod, + "progressive_car_spawn_rate": ProgressiveCarSpawnRateMod, + "collectible_value_time_decay": TimeDecayCollectibleValueMod, +} + + +class UpNDownEnvMod(JaxAtariModController): + REGISTRY = UPNDOWN_MOD_REGISTRY + + def __init__( + self, + env, + mods_config: list = [], + allow_conflicts: bool = False, + ): + super().__init__( + env=env, + mods_config=mods_config, + allow_conflicts=allow_conflicts, + registry=self.REGISTRY, + ) From bb71effc0f7928c983d1d06e41a51576a77e27b1 Mon Sep 17 00:00:00 2001 From: shaik05 Date: Fri, 6 Mar 2026 13:38:07 +0100 Subject: [PATCH 34/76] Allow backward jumping and remove steep road --- src/jaxatari/games/jax_upndown.py | 34 +++++++++++++++++++------------ 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 773e2e1b2..2af620f63 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -1026,18 +1026,26 @@ def _death_step(self, state: UpNDownState) -> UpNDownState: def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: up = jnp.logical_or(action == Action.UP, action == Action.UPFIRE) down = jnp.logical_or(action == Action.DOWN, action == Action.DOWNFIRE) - jump_pressed = jnp.logical_or(action == Action.FIRE, jnp.logical_or(action == Action.UPFIRE, action == Action.DOWNFIRE)) - - # Check if on a steep road section FIRST (before applying speed changes) - is_on_steep_road = self._is_steep_road_segment( - state.player_car.current_road, - state.player_car.road_index_A, - state.player_car.road_index_B, + jump = jnp.logical_or(action == Action.FIRE, jnp.logical_or(action == Action.UPFIRE, action == Action.DOWNFIRE)) + player_speed = state.player_car.speed.astype(jnp.int32) + + player_speed = jax.lax.cond( + jnp.logical_and(state.player_car.speed < self.consts.MAX_SPEED, up), + lambda s: s + 1, + lambda s: s, + operand=player_speed, ) - - # Calculate progress through steep segment (0.0 = bottom, 1.0 = top) - steep_progress = self._get_steep_segment_progress( - state.player_car.position.y, + + player_speed = jax.lax.cond( + jnp.logical_and(state.player_car.speed > -self.consts.MAX_SPEED, down), + lambda s: s - 1, + lambda s: s, + operand=player_speed, + ) + + # Check if on a steep road section (no X direction change) and apply speed reduction + # This simulates steep road sections that require a jump to pass when going upward + is_on_steep_road = self._is_steep_road_segment( state.player_car.current_road, state.player_car.road_index_A, state.player_car.road_index_B, @@ -1063,7 +1071,7 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: player_speed = jnp.where( jnp.logical_and( - jnp.logical_and(should_change_speed, is_accelerating), + jnp.logical_and(up,True), jnp.logical_and(player_speed < self.consts.MAX_SPEED, can_accelerate) ), player_speed + 1, @@ -1620,7 +1628,7 @@ def _respawn_after_collision(self, state: UpNDownState, new_lives: chex.Array) - width=self.consts.PLAYER_SIZE[0], height=self.consts.PLAYER_SIZE[1], ), - speed=jnp.array(0.0, dtype=jnp.float32), + speed=jnp.array(0, dtype=jnp.int32), direction_x=jnp.array(0, dtype=jnp.int32), current_road=respawn_road, road_index_A=start_segment, From 9a985bcb92111f9a150dd531a2cc13ef301bf00e Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sat, 21 Mar 2026 17:49:30 +0100 Subject: [PATCH 35/76] Revert "Allow backward jumping and remove steep road" This reverts commit 23cea0501aa2ee1a6024d94d98048013301ff7a5. --- src/jaxatari/games/jax_upndown.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 2af620f63..ff27630c3 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -1027,7 +1027,7 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: up = jnp.logical_or(action == Action.UP, action == Action.UPFIRE) down = jnp.logical_or(action == Action.DOWN, action == Action.DOWNFIRE) jump = jnp.logical_or(action == Action.FIRE, jnp.logical_or(action == Action.UPFIRE, action == Action.DOWNFIRE)) - player_speed = state.player_car.speed.astype(jnp.int32) + player_speed = state.player_car.speed player_speed = jax.lax.cond( jnp.logical_and(state.player_car.speed < self.consts.MAX_SPEED, up), @@ -1071,7 +1071,7 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: player_speed = jnp.where( jnp.logical_and( - jnp.logical_and(up,True), + jnp.logical_and(should_change_speed, is_accelerating), jnp.logical_and(player_speed < self.consts.MAX_SPEED, can_accelerate) ), player_speed + 1, @@ -1628,7 +1628,7 @@ def _respawn_after_collision(self, state: UpNDownState, new_lives: chex.Array) - width=self.consts.PLAYER_SIZE[0], height=self.consts.PLAYER_SIZE[1], ), - speed=jnp.array(0, dtype=jnp.int32), + speed=jnp.array(0.0, dtype=jnp.float32), direction_x=jnp.array(0, dtype=jnp.int32), current_road=respawn_road, road_index_A=start_segment, From 7e1c8844a2282dcb37801dcc83a39cb7a3acedce Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sat, 21 Mar 2026 20:22:26 +0100 Subject: [PATCH 36/76] revert the in main game mod implementations --- src/jaxatari/games/jax_upndown.py | 118 ++++-------------------------- 1 file changed, 15 insertions(+), 103 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index ff27630c3..4f7a3af1f 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -50,11 +50,7 @@ class UpNDownConstants(NamedTuple): LANDING_COLLISION_DISTANCE: float = 12.0 # Larger collision distance when landing (increased for easier enemy kills) GROUND_COLLISION_DISTANCE: float = 3.0 # Tight collision distance for ground collisions LATE_JUMP_ENEMY_SCORE: int = 400 - STEEP_ROAD_SPEED_REDUCTION_INTERVAL: int = 12 # Frames between each speed reduction on steep roads - STEEP_ROAD_MIN_SPEED: float = -2.0 # Minimum speed on steep roads - STEEP_ROAD_JUMP_BOOST: float = 1.5 # Multiplier for jump height on steep roads - STEEP_ROAD_RECOVERY_BOOST: float = 0.8 # Speed boost after leaving steep road - STEEP_ROAD_COOLDOWN: int = 5 + STEEP_ROAD_SPEED_REDUCTION_INTERVAL: int = 8 # Frames between each speed reduction on steep roads PASSIVE_SCORE_INTERVAL: int = 60 # Steps between passive score awards PASSIVE_SCORE_AMOUNT: int = 10 # Points awarded for passive scoring COLLISION_THRESHOLD: float = 5.0 # Distance threshold for flag/collectible collision @@ -256,8 +252,8 @@ def _compute_movement_timing(self, speed: chex.Array, step_counter: chex.Array) Returns: Tuple of (move_y, move_x, step_size, speed_sign) """ - abs_speed = jnp.abs(speed).astype(jnp.int32) - speed_index = jnp.minimum(abs_speed, self._speed_dividers.shape[0] - 1).astype(jnp.int32) + abs_speed = jnp.abs(speed) + speed_index = jnp.minimum(abs_speed, jnp.int32(self._speed_dividers.shape[0] - 1)) speed_divider = self._speed_dividers[speed_index] effective_divider = jnp.maximum(1, speed_divider) period = jnp.maximum(1, 16 // effective_divider) @@ -269,69 +265,6 @@ def _compute_movement_timing(self, speed: chex.Array, step_counter: chex.Array) step_size = jnp.where(speed_index >= 6, 1.5 + (speed_index - 6) * 0.2, 1.0) return move_y, move_x, step_size, speed_sign - def _apply_steep_road_penalty( - self, - speed: chex.Array, - is_on_steep_road: chex.Array, - steep_road_timer: chex.Array, - is_jumping: chex.Array, - jump_cooldown: chex.Array, - ) -> Tuple[chex.Array, chex.Array, chex.Array]: - """ - Apply enhanced steep road penalty with perfect balance and edge case handling. - - - Dynamically reduces speed on steep roads when going upward. - - Provides jump boost and recovery for better flow. - - Includes cooldown to prevent rapid reductions. - - Returns: (new_speed, new_timer, jump_boost_multiplier) - """ - going_up = speed > 0 - on_steep_going_up = jnp.logical_and(is_on_steep_road, going_up) - in_cooldown = steep_road_timer < 0 # Negative timer indicates cooldown - - # Increment timer only if not in cooldown and on steep road going up - timer_increment = jax.lax.cond( - jnp.logical_and(on_steep_going_up, jnp.logical_not(in_cooldown)), - lambda _: 1, - lambda _: 0, - operand=None, - ) - new_timer = steep_road_timer + timer_increment - - # Apply reduction when timer reaches interval and not in cooldown - should_reduce = jnp.logical_and( - on_steep_going_up, - jnp.logical_and(new_timer >= self.consts.STEEP_ROAD_SPEED_REDUCTION_INTERVAL, jnp.logical_not(in_cooldown)) - ) - - # Proportional reduction: stronger for higher speeds, with minimum cap - reduction_factor = jnp.maximum(0.05, speed * 0.15) # 5-15% of speed - reduced_speed = jnp.maximum(speed - reduction_factor, self.consts.STEEP_ROAD_MIN_SPEED) - - # Set cooldown after reduction (negative timer) - final_timer = jax.lax.cond( - should_reduce, - lambda _: -self.consts.STEEP_ROAD_COOLDOWN, - lambda _: new_timer, - operand=None, - ) - - # Recovery boost after leaving steep road (not jumping) - just_left_steep = jnp.logical_and(jnp.logical_not(on_steep_going_up), jnp.logical_not(is_jumping)) - recovery_boost = jax.lax.cond(just_left_steep, lambda _: self.consts.STEEP_ROAD_RECOVERY_BOOST, lambda _: 0.0, operand=None) - - # Jump boost if jumping on steep road - jump_boost = jax.lax.cond( - jnp.logical_and(on_steep_going_up, jump_cooldown > 0), - lambda _: self.consts.STEEP_ROAD_JUMP_BOOST, - lambda _: 1.0, - operand=None, - ) - - final_speed = jax.lax.cond(should_reduce, lambda _: reduced_speed + recovery_boost, lambda _: speed + recovery_boost, operand=None) - - return final_speed, final_timer, jump_boost @partial(jax.jit, static_argnums=(0,)) def _sample_enemy_spawn_road(self, rng_key: chex.PRNGKey) -> chex.Array: @@ -1026,31 +959,23 @@ def _death_step(self, state: UpNDownState) -> UpNDownState: def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: up = jnp.logical_or(action == Action.UP, action == Action.UPFIRE) down = jnp.logical_or(action == Action.DOWN, action == Action.DOWNFIRE) - jump = jnp.logical_or(action == Action.FIRE, jnp.logical_or(action == Action.UPFIRE, action == Action.DOWNFIRE)) - player_speed = state.player_car.speed - - player_speed = jax.lax.cond( - jnp.logical_and(state.player_car.speed < self.consts.MAX_SPEED, up), - lambda s: s + 1, - lambda s: s, - operand=player_speed, - ) - - player_speed = jax.lax.cond( - jnp.logical_and(state.player_car.speed > -self.consts.MAX_SPEED, down), - lambda s: s - 1, - lambda s: s, - operand=player_speed, - ) - - # Check if on a steep road section (no X direction change) and apply speed reduction - # This simulates steep road sections that require a jump to pass when going upward + jump_pressed = jnp.logical_or(action == Action.FIRE, jnp.logical_or(action == Action.UPFIRE, action == Action.DOWNFIRE)) + + # Check if on a steep road section FIRST (before applying speed changes) is_on_steep_road = self._is_steep_road_segment( state.player_car.current_road, state.player_car.road_index_A, state.player_car.road_index_B, ) + # Calculate progress through steep segment (0.0 = bottom, 1.0 = top) + steep_progress = self._get_steep_segment_progress( + state.player_car.position.y, + state.player_car.current_road, + state.player_car.road_index_A, + state.player_car.road_index_B, + ) + # Determine if player is on steep road going up (not jumping) on_steep_not_jumping = jnp.logical_and(is_on_steep_road, jnp.logical_not(state.is_jumping)) @@ -1628,7 +1553,7 @@ def _respawn_after_collision(self, state: UpNDownState, new_lives: chex.Array) - width=self.consts.PLAYER_SIZE[0], height=self.consts.PLAYER_SIZE[1], ), - speed=jnp.array(0.0, dtype=jnp.float32), + speed=jnp.array(0, dtype=jnp.int32), direction_x=jnp.array(0, dtype=jnp.int32), current_road=respawn_road, road_index_A=start_segment, @@ -2160,19 +2085,6 @@ def __init__(self, consts: UpNDownConstants = None): channels=3, #downscale=(84, 84) ) - def _createBackgroundSprite(self, dimensions: Tuple[int, int]) -> jnp.ndarray: - - height, width = dimensions - # Create a vertical gradient: blue at top, lighter blue at bottom - top_color = jnp.array([135, 206, 235, 255], dtype=jnp.uint8) # Sky blue - bottom_color = jnp.array([173, 216, 230, 255], dtype=jnp.uint8) # Lighter sky blue - - # Linear interpolation for gradient - y_coords = jnp.arange(height, dtype=jnp.float32) / (height - 1) - gradient = jnp.outer(y_coords, bottom_color - top_color) + top_color - gradient = jnp.clip(gradient, 0, 255).astype(jnp.uint8) - - return gradient self.jr = render_utils.JaxRenderingUtils(self.config) background = self._createBackgroundSprite(self.config.game_dimensions) From d767f4dbd3cdb1e8e3fea7cb6d02a9e35b59cefd Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sat, 21 Mar 2026 20:23:31 +0100 Subject: [PATCH 37/76] moved mod plugin file to new folder --- src/jaxatari/games/mods/{ => upndown}/upndown_mod_plugins.py | 0 src/jaxatari/games/mods/upndown_mods.py | 2 +- 2 files changed, 1 insertion(+), 1 deletion(-) rename src/jaxatari/games/mods/{ => upndown}/upndown_mod_plugins.py (100%) diff --git a/src/jaxatari/games/mods/upndown_mod_plugins.py b/src/jaxatari/games/mods/upndown/upndown_mod_plugins.py similarity index 100% rename from src/jaxatari/games/mods/upndown_mod_plugins.py rename to src/jaxatari/games/mods/upndown/upndown_mod_plugins.py diff --git a/src/jaxatari/games/mods/upndown_mods.py b/src/jaxatari/games/mods/upndown_mods.py index 0dd2123b5..872515282 100644 --- a/src/jaxatari/games/mods/upndown_mods.py +++ b/src/jaxatari/games/mods/upndown_mods.py @@ -1,5 +1,5 @@ from jaxatari.modification import JaxAtariModController -from jaxatari.games.mods.upndown_mod_plugins import ( +from jaxatari.games.mods.upndown.upndown_mod_plugins import ( AllowJumpBackwardsMod, RemoveStepRoadsMod, HigherPlayerSpeedMod, From e36af2a5f32fc8eb267a01272952a4a856defb21 Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Thu, 6 Nov 2025 17:09:00 +0100 Subject: [PATCH 38/76] initial commit using pong as template --- src/jaxatari/games/UpNDown.py | 663 ++++++++++++++++++++++++++++++++++ 1 file changed, 663 insertions(+) create mode 100644 src/jaxatari/games/UpNDown.py diff --git a/src/jaxatari/games/UpNDown.py b/src/jaxatari/games/UpNDown.py new file mode 100644 index 000000000..904a45bb3 --- /dev/null +++ b/src/jaxatari/games/UpNDown.py @@ -0,0 +1,663 @@ +from jax._src.pjit import JitWrapped +import os +from functools import partial +from typing import NamedTuple, Tuple +import jax.lax +import jax.numpy as jnp +import chex + +import jaxatari.spaces as spaces +from jaxatari.renderers import JAXGameRenderer +from jaxatari.rendering import jax_rendering_utils as render_utils +from jaxatari.environment import JaxEnvironment, JAXAtariAction as Action + +class UpNDownConstants(NamedTuple): + FRAME_SKIP: int = 4 + DIFFICULTIES: chex.Array = jnp.array([0, 1, 2, 3, 4, 5]) + ACTION_REPEAT_PROBS: float = 0.25 + + +# immutable state container +class UpNDownState(NamedTuple): + player_y: chex.Array + player_speed: chex.Array + score: chex.Array + difficulty: chex.Array + + + +class EntityPosition(NamedTuple): + x: jnp.ndarray + y: jnp.ndarray + width: jnp.ndarray + height: jnp.ndarray + + +class EnemyCar(NamedTuple): + position: EntityPosition + speed: chex.Array + type: chex.Array + + +class UpNDownObservation(NamedTuple): + player: EntityPosition + enemies: jnp.ndarray + score: jnp.ndarray + +class Collectible(NamedTuple): + position: EntityPosition + type: chex.Array + value: chex.Array + + +class UpNDownInfo(NamedTuple): + time: jnp.ndarray + + +class JaxUpNDown(JaxEnvironment[UpNDownState, UpNDownObservation, UpNDownInfo, UpNDownConstants]): + def __init__(self, consts: UpNDownConstants = None, reward_funcs: list[callable]=None): + consts = consts or UpNDownConstants() + super().__init__(consts) + self.renderer = UpNDownRenderer(self.consts) + if reward_funcs is not None: + reward_funcs = tuple(reward_funcs) + self.reward_funcs = reward_funcs + self.action_set = [ + Action.NOOP, + Action.FIRE, + Action.RIGHT, + Action.LEFT, + Action.RIGHTFIRE, + Action.LEFTFIRE, + ] + self.obs_size = 3*4+1+1 + + def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: + up = jnp.logical_or(action == Action.LEFT, action == Action.LEFTFIRE) + down = jnp.logical_or(action == Action.RIGHT, action == Action.RIGHTFIRE) + + acceleration = self.consts.PLAYER_ACCELERATION[state.acceleration_counter] + + touches_wall = jnp.logical_or( + state.player_y < self.consts.WALL_TOP_Y, + state.player_y + self.consts.PLAYER_SIZE[1] > self.consts.WALL_BOTTOM_Y, + ) + + player_speed = state.player_speed + + player_speed = jax.lax.cond( + jnp.logical_or(jnp.logical_not(jnp.logical_or(up, down)), touches_wall), + lambda s: jnp.round(s / 2).astype(jnp.int32), + lambda s: s, + operand=player_speed, + ) + + direction_change_up = jnp.logical_and(up, state.player_speed > 0) + player_speed = jax.lax.cond( + direction_change_up, + lambda s: 0, + lambda s: s, + operand=player_speed, + ) + direction_change_down = jnp.logical_and(down, state.player_speed < 0) + + player_speed = jax.lax.cond( + direction_change_down, + lambda s: 0, + lambda s: s, + operand=player_speed, + ) + + direction_change = jnp.logical_or(direction_change_up, direction_change_down) + acceleration_counter = jax.lax.cond( + direction_change, + lambda _: 0, + lambda s: s, + operand=state.acceleration_counter, + ) + + player_speed = jax.lax.cond( + up, + lambda s: jnp.maximum(s - acceleration, -self.consts.MAX_SPEED), + lambda s: s, + operand=player_speed, + ) + + player_speed = jax.lax.cond( + down, + lambda s: jnp.minimum(s + acceleration, self.consts.MAX_SPEED), + lambda s: s, + operand=player_speed, + ) + + new_acceleration_counter = jax.lax.cond( + jnp.logical_or(up, down), + lambda s: jnp.minimum(s + 1, 15), + lambda s: 0, + operand=acceleration_counter, + ) + + proposed_player_y = jnp.clip( + state.player_y + player_speed, + self.consts.WALL_TOP_Y + self.consts.WALL_TOP_HEIGHT - 10, + self.consts.WALL_BOTTOM_Y - 4, + ) + + # Match original timing/buffering behavior + new_player_y, new_player_speed, new_acc_counter = jax.lax.cond( + state.step_counter % 2 == 0, + lambda _: (proposed_player_y, player_speed, new_acceleration_counter), + lambda _: (state.player_y, state.player_speed, state.acceleration_counter), + operand=None, + ) + + buffer = jax.lax.cond( + jax.lax.eq(state.buffer, state.player_y), + lambda _: new_player_y, + lambda _: state.buffer, + operand=None, + ) + final_player_y = state.buffer + + return UpNDownState( + player_y=final_player_y, + player_speed=new_player_speed, + ball_x=state.ball_x, + ball_y=state.ball_y, + enemy_y=state.enemy_y, + enemy_speed=state.enemy_speed, + ball_vel_x=state.ball_vel_x, + ball_vel_y=state.ball_vel_y, + player_score=state.player_score, + enemy_score=state.enemy_score, + step_counter=state.step_counter, + acceleration_counter=new_acc_counter, + buffer=buffer, + ) + + def _ball_step(self, state: UpNDownState, action) -> UpNDownState: + ball_x = state.ball_x + state.ball_vel_x + ball_y = state.ball_y + state.ball_vel_y + + wall_bounce = jnp.logical_or( + ball_y <= self.consts.WALL_TOP_Y + self.consts.WALL_TOP_HEIGHT - self.consts.BALL_SIZE[1], + ball_y >= self.consts.WALL_BOTTOM_Y, + ) + ball_vel_y = jnp.where(wall_bounce, -state.ball_vel_y, state.ball_vel_y) + + player_paddle_hit = jnp.logical_and( + jnp.logical_and(self.consts.PLAYER_X <= ball_x, ball_x <= self.consts.PLAYER_X + self.consts.PLAYER_SIZE[0]), + state.ball_vel_x > 0, + ) + + player_paddle_hit = jnp.logical_and( + player_paddle_hit, + jnp.logical_and( + state.player_y - self.consts.BALL_SIZE[1] <= ball_y, + ball_y <= state.player_y + self.consts.PLAYER_SIZE[1] + self.consts.BALL_SIZE[1], + ), + ) + + enemy_paddle_hit = jnp.logical_and( + jnp.logical_and(self.consts.ENEMY_X <= ball_x, ball_x <= self.consts.ENEMY_X + self.consts.ENEMY_SIZE[0] - 1), + state.ball_vel_x < 0, + ) + + enemy_paddle_hit = jnp.logical_and( + enemy_paddle_hit, + jnp.logical_and( + state.enemy_y - self.consts.BALL_SIZE[1] <= ball_y, + ball_y <= state.enemy_y + self.consts.ENEMY_SIZE[1] + self.consts.BALL_SIZE[1], + ), + ) + + paddle_hit = jnp.logical_or(player_paddle_hit, enemy_paddle_hit) + + section_height = self.consts.PLAYER_SIZE[1] / 5 + + hit_position = jnp.where( + paddle_hit, + jnp.where( + player_paddle_hit, + jnp.where( + ball_y < state.player_y + section_height, + -2.0, + jnp.where( + ball_y < state.player_y + 2 * section_height, + -1.0, + jnp.where( + ball_y < state.player_y + 3 * section_height, + 0.0, + jnp.where( + ball_y < state.player_y + 4 * section_height, + 1.0, + 2.0, + ), + ), + ), + ), + jnp.where( + ball_y < state.enemy_y + section_height, + -2.0, + jnp.where( + ball_y < state.enemy_y + 2 * section_height, + -1.0, + jnp.where( + ball_y < state.enemy_y + 3 * section_height, + 0.0, + jnp.where( + ball_y < state.enemy_y + 4 * section_height, + 1.0, + 2.0, + ), + ), + ), + ), + ), + 0.0, + ) + + paddle_speed = jnp.where( + player_paddle_hit, + state.player_speed, + jnp.where( + enemy_paddle_hit, + state.enemy_speed, + 0.0, + ), + ) + + ball_vel_y = jnp.where(paddle_hit, hit_position, ball_vel_y) + + boost_triggered = jnp.logical_and( + player_paddle_hit, + jnp.logical_or( + jnp.logical_or(action == Action.LEFTFIRE, action == Action.RIGHTFIRE), + action == Action.FIRE, + ), + ) + player_max_hit = jnp.logical_and(player_paddle_hit, state.player_speed == self.consts.MAX_SPEED) + ball_vel_x = jnp.where( + jnp.logical_or(boost_triggered, player_max_hit), + state.ball_vel_x + + jnp.sign(state.ball_vel_x), + state.ball_vel_x, + ) + + ball_vel_x = jnp.where( + paddle_hit, + -ball_vel_x, + ball_vel_x, + ) + + return UpNDownState( + player_y=state.player_y, + player_speed=state.player_speed, + ball_x=ball_x.astype(jnp.int32), + ball_y=ball_y.astype(jnp.int32), + enemy_y=state.enemy_y, + enemy_speed=state.enemy_speed, + ball_vel_x=ball_vel_x.astype(jnp.int32), + ball_vel_y=ball_vel_y.astype(jnp.int32), + player_score=state.player_score, + enemy_score=state.enemy_score, + step_counter=state.step_counter, + acceleration_counter=state.acceleration_counter, + buffer=state.buffer, + ) + + def _enemy_step(self, state: UpNDownState) -> UpNDownState: + should_move = state.step_counter % 8 != 0 + + direction = jnp.sign(state.ball_y - state.enemy_y) + + new_y = state.enemy_y + (direction * self.consts.ENEMY_STEP_SIZE).astype(jnp.int32) + enemy_y = jax.lax.cond( + should_move, lambda _: new_y, lambda _: state.enemy_y, operand=None + ) + return UpNDownState( + player_y=state.player_y, + player_speed=state.player_speed, + ball_x=state.ball_x, + ball_y=state.ball_y, + enemy_y=enemy_y.astype(jnp.int32), + enemy_speed=state.enemy_speed, + ball_vel_x=state.ball_vel_x, + ball_vel_y=state.ball_vel_y, + player_score=state.player_score, + enemy_score=state.enemy_score, + step_counter=state.step_counter, + acceleration_counter=state.acceleration_counter, + buffer=state.buffer, + ) + + def _score_and_reset(self, state: UpNDownState) -> UpNDownState: + player_goal = state.ball_x < 4 + enemy_goal = state.ball_x > 156 + ball_reset = jnp.logical_or(enemy_goal, player_goal) + + player_score = jax.lax.cond( + player_goal, + lambda s: s + 1, + lambda s: s, + operand=state.player_score, + ) + enemy_score = jax.lax.cond( + enemy_goal, + lambda s: s + 1, + lambda s: s, + operand=state.enemy_score, + ) + + current_values = ( + state.ball_x.astype(jnp.int32), + state.ball_y.astype(jnp.int32), + state.ball_vel_x.astype(jnp.int32), + state.ball_vel_y.astype(jnp.int32), + ) + ball_x_final, ball_y_final, ball_vel_x_final, ball_vel_y_final = jax.lax.cond( + ball_reset, + lambda x: self._reset_ball_after_goal((state, enemy_goal)), + lambda x: x, + operand=current_values, + ) + + step_counter = jax.lax.cond( + ball_reset, + lambda s: jnp.array(0), + lambda s: s + 1, + operand=state.step_counter, + ) + + enemy_y_final = jax.lax.cond( + ball_reset, + lambda s: self.consts.BALL_START_Y.astype(jnp.int32), + lambda s: state.enemy_y.astype(jnp.int32), + operand=None, + ) + + ball_x_final = jax.lax.cond( + step_counter < 60, + lambda s: self.consts.BALL_START_X.astype(jnp.int32), + lambda s: s, + operand=ball_x_final, + ) + ball_y_final = jax.lax.cond( + step_counter < 60, + lambda s: self.consts.BALL_START_Y.astype(jnp.int32), + lambda s: s, + operand=ball_y_final, + ) + + return UpNDownState( + player_y=state.player_y, + player_speed=state.player_speed, + ball_x=ball_x_final, + ball_y=ball_y_final, + enemy_y=enemy_y_final, + enemy_speed=state.enemy_speed, + ball_vel_x=ball_vel_x_final, + ball_vel_y=ball_vel_y_final, + player_score=player_score, + enemy_score=enemy_score, + step_counter=step_counter, + acceleration_counter=state.acceleration_counter, + buffer=state.buffer, + ) + + def _reset_ball_after_goal(self, state_and_goal: Tuple[UpNDownState, bool]) -> Tuple[chex.Array, chex.Array, chex.Array, chex.Array]: + state, scored_right = state_and_goal + + ball_vel_y = jnp.where( + state.ball_y > self.consts.BALL_START_Y, + 1, + -1, + ).astype(jnp.int32) + + ball_vel_x = jnp.where( + scored_right, 1, -1 + ).astype(jnp.int32) + + return ( + self.consts.BALL_START_X.astype(jnp.int32), + self.consts.BALL_START_Y.astype(jnp.int32), + ball_vel_x.astype(jnp.int32), + ball_vel_y.astype(jnp.int32), + ) + + def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: + state = UpNDownState( + player_y=jnp.array(96).astype(jnp.int32), + player_speed=jnp.array(0.0).astype(jnp.int32), + ball_x=jnp.array(78).astype(jnp.int32), + ball_y=jnp.array(115).astype(jnp.int32), + enemy_y=jnp.array(115).astype(jnp.int32), + enemy_speed=jnp.array(0.0).astype(jnp.int32), + ball_vel_x=self.consts.BALL_SPEED[0].astype(jnp.int32), + ball_vel_y=self.consts.BALL_SPEED[1].astype(jnp.int32), + player_score=jnp.array(0).astype(jnp.int32), + enemy_score=jnp.array(0).astype(jnp.int32), + step_counter=jnp.array(0).astype(jnp.int32), + acceleration_counter=jnp.array(0).astype(jnp.int32), + buffer=jnp.array(96).astype(jnp.int32), + ) + initial_obs = self._get_observation(state) + + return initial_obs, state + + @partial(jax.jit, static_argnums=(0,)) + def step(self, state: UpNDownState, action: chex.Array) -> Tuple[UpNDownObservation, UpNDownState, float, bool, UpNDownInfo]: + previous_state = state + state = self._player_step(state, action) + state = self._enemy_step(state) + state = self._ball_step(state, action) + state = self._score_and_reset(state) + + done = self._get_done(state) + env_reward = self._get_reward(previous_state, state) + info = self._get_info(state) + observation = self._get_observation(state) + + return observation, state, env_reward, done, info + + + def render(self, state: UpNDownState) -> jnp.ndarray: + return self.renderer.render(state) + + def _get_observation(self, state: UpNDownState): + player = EntityPosition( + x=jnp.array(self.consts.PLAYER_X), + y=state.player_y, + width=jnp.array(self.consts.PLAYER_SIZE[0]), + height=jnp.array(self.consts.PLAYER_SIZE[1]), + ) + + enemy = EntityPosition( + x=jnp.array(self.consts.ENEMY_X), + y=state.enemy_y, + width=jnp.array(self.consts.ENEMY_SIZE[0]), + height=jnp.array(self.consts.ENEMY_SIZE[1]), + ) + + ball = EntityPosition( + x=state.ball_x, + y=state.ball_y, + width=jnp.array(self.consts.BALL_SIZE[0]), + height=jnp.array(self.consts.BALL_SIZE[1]), + ) + return UpNDownObservation( + player=player, + enemy=enemy, + ball=ball, + score_player=state.player_score, + score_enemy=state.enemy_score, + ) + + @partial(jax.jit, static_argnums=(0,)) + def obs_to_flat_array(self, obs: UpNDownObservation) -> jnp.ndarray: + return jnp.concatenate([ + obs.player.x.flatten(), + obs.player.y.flatten(), + obs.player.height.flatten(), + obs.player.width.flatten(), + obs.enemy.x.flatten(), + obs.enemy.y.flatten(), + obs.enemy.height.flatten(), + obs.enemy.width.flatten(), + obs.ball.x.flatten(), + obs.ball.y.flatten(), + obs.ball.height.flatten(), + obs.ball.width.flatten(), + obs.score_player.flatten(), + obs.score_enemy.flatten() + ] + ) + + def action_space(self) -> spaces.Discrete: + return spaces.Discrete(6) + + def observation_space(self) -> spaces: + return spaces.Dict({ + "player": spaces.Dict({ + "x": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), + "y": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), + "width": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), + "height": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), + }), + "enemy": spaces.Dict({ + "x": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), + "y": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), + "width": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), + "height": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), + }), + "ball": spaces.Dict({ + "x": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), + "y": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), + "width": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), + "height": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), + }), + "score_player": spaces.Box(low=0, high=21, shape=(), dtype=jnp.int32), + "score_enemy": spaces.Box(low=0, high=21, shape=(), dtype=jnp.int32), + }) + + def image_space(self) -> spaces.Box: + return spaces.Box( + low=0, + high=255, + shape=(210, 160, 3), + dtype=jnp.uint8 + ) + + @partial(jax.jit, static_argnums=(0,)) + def _get_info(self, state: UpNDownState, ) -> UpNDownInfo: + return UpNDownInfo(time=state.step_counter) + + @partial(jax.jit, static_argnums=(0,)) + def _get_reward(self, previous_state: UpNDownState, state: UpNDownState): + return (state.player_score - state.enemy_score) - ( + previous_state.player_score - previous_state.enemy_score + ) + + @partial(jax.jit, static_argnums=(0,)) + def _get_done(self, state: UpNDownState) -> bool: + return jnp.logical_or( + jnp.greater_equal(state.player_score, 21), + jnp.greater_equal(state.enemy_score, 21), + ) + +class UpNDownRenderer(JAXGameRenderer): + def __init__(self, consts: UpNDownConstants = None): + super().__init__() + self.consts = consts or UpNDownConstants() + self.config = render_utils.RendererConfig( + game_dimensions=(210, 160), + channels=3, + #downscale=(84, 84) + ) + self.jr = render_utils.JaxRenderingUtils(self.config) + # 1. Create procedural assets for both walls + wall_sprite_top = self._create_wall_sprite(self.consts.WALL_TOP_HEIGHT) + wall_sprite_bottom = self._create_wall_sprite(self.consts.WALL_BOTTOM_HEIGHT) + + # 2. Update asset config to include both walls + asset_config = self._get_asset_config(wall_sprite_top, wall_sprite_bottom) + sprite_path = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/UpNDown" + + # 3. Make a single call to the setup function + ( + self.PALETTE, + self.SHAPE_MASKS, + self.BACKGROUND, + self.COLOR_TO_ID, + self.FLIP_OFFSETS + ) = self.jr.load_and_setup_assets(asset_config, sprite_path) + + def _create_wall_sprite(self, height: int) -> jnp.ndarray: + """Procedurally creates an RGBA sprite for a wall of given height.""" + wall_color_rgba = (*self.consts.SCORE_COLOR, 255) # e.g., (236, 236, 236, 255) + wall_shape = (height, self.consts.WIDTH, 4) + wall_sprite = jnp.tile(jnp.array(wall_color_rgba, dtype=jnp.uint8), (*wall_shape[:2], 1)) + return wall_sprite + + def _get_asset_config(self, wall_sprite_top: jnp.ndarray, wall_sprite_bottom: jnp.ndarray) -> list: + """Returns the declarative manifest of all assets for the game, including both wall sprites.""" + return [ + {'name': 'background', 'type': 'background', 'file': 'background.npy'}, + {'name': 'player', 'type': 'single', 'file': 'player.npy'}, + {'name': 'enemy', 'type': 'single', 'file': 'enemy.npy'}, + {'name': 'ball', 'type': 'single', 'file': 'ball.npy'}, + {'name': 'player_digits', 'type': 'digits', 'pattern': 'player_score_{}.npy'}, + {'name': 'enemy_digits', 'type': 'digits', 'pattern': 'enemy_score_{}.npy'}, + # Add the procedurally created sprites to the manifest + {'name': 'wall_top', 'type': 'procedural', 'data': wall_sprite_top}, + {'name': 'wall_bottom', 'type': 'procedural', 'data': wall_sprite_bottom}, + ] + + @partial(jax.jit, static_argnums=(0,)) + def render(self, state): + raster = self.jr.create_object_raster(self.BACKGROUND) + + player_mask = self.SHAPE_MASKS["player"] + raster = self.jr.render_at(raster, self.consts.PLAYER_X, state.player_y, player_mask) + + enemy_mask = self.SHAPE_MASKS["enemy"] + raster = self.jr.render_at(raster, self.consts.ENEMY_X, state.enemy_y, enemy_mask) + + ball_mask = self.SHAPE_MASKS["ball"] + raster = self.jr.render_at(raster, state.ball_x, state.ball_y, ball_mask) + + # --- Stamp Walls and Score (using the same color/ID) --- + score_color_tuple = self.consts.SCORE_COLOR # (236, 236, 236) + score_id = self.COLOR_TO_ID[score_color_tuple] + + # Draw walls (using separate sprites for top and bottom) + raster = self.jr.render_at(raster, 0, self.consts.WALL_TOP_Y, self.SHAPE_MASKS["wall_top"]) + raster = self.jr.render_at(raster, 0, self.consts.WALL_BOTTOM_Y, self.SHAPE_MASKS["wall_bottom"]) + + # Stamp Score using the label utility + player_digits = self.jr.int_to_digits(state.player_score, max_digits=2) + enemy_digits = self.jr.int_to_digits(state.enemy_score, max_digits=2) + + # Note: The logic for single/double digits is complex for a jitted function. + player_digit_masks = self.SHAPE_MASKS["player_digits"] # Assumes single color + enemy_digit_masks = self.SHAPE_MASKS["enemy_digits"] # Assumes single color + + is_player_single_digit = state.player_score < 10 + player_start_index = jax.lax.select(is_player_single_digit, 1, 0) + player_num_to_render = jax.lax.select(is_player_single_digit, 1, 2) + player_render_x = jax.lax.select(is_player_single_digit, + 120 + 16 // 2, + 120) + + raster = self.jr.render_label_selective(raster, player_render_x, 3, player_digits, player_digit_masks, player_start_index, player_num_to_render, spacing=16) + + is_enemy_single_digit = state.enemy_score < 10 + enemy_start_index = jax.lax.select(is_enemy_single_digit, 1, 0) + enemy_num_to_render = jax.lax.select(is_enemy_single_digit, 1, 2) + enemy_render_x = jax.lax.select(is_enemy_single_digit, + 10 + 16 // 2, + 10) + + raster = self.jr.render_label_selective(raster, enemy_render_x, 3, enemy_digits, enemy_digit_masks, enemy_start_index, enemy_num_to_render, spacing=16) + + return self.jr.render_from_palette(raster, self.PALETTE) \ No newline at end of file From 38a3427566ada2080ed96916e8e2a3976318eec2 Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Fri, 7 Nov 2025 17:52:12 +0100 Subject: [PATCH 39/76] rough design of potential car movements --- src/jaxatari/games/UpNDown.py | 402 +++++++++++++--------------------- 1 file changed, 157 insertions(+), 245 deletions(-) diff --git a/src/jaxatari/games/UpNDown.py b/src/jaxatari/games/UpNDown.py index 904a45bb3..af62fe461 100644 --- a/src/jaxatari/games/UpNDown.py +++ b/src/jaxatari/games/UpNDown.py @@ -15,28 +15,44 @@ class UpNDownConstants(NamedTuple): FRAME_SKIP: int = 4 DIFFICULTIES: chex.Array = jnp.array([0, 1, 2, 3, 4, 5]) ACTION_REPEAT_PROBS: float = 0.25 - - -# immutable state container -class UpNDownState(NamedTuple): - player_y: chex.Array - player_speed: chex.Array - score: chex.Array - difficulty: chex.Array + MAX_SPEED: int = 4 + JUMP_FRAMES: int = 10 + LANDING_ZONE: int = 15 + FIRST_ROAD_LENGTH: int = 4 + SECOND_ROAD_LENGTH: int = 4 + FIRST_TRACK_CORNERS_X: chex.Array = jnp.array([20, 50, 80, 100]) #get actual values + FIRST_TRACK_CORNERS_Y: chex.Array = jnp.array([20, 50, 80, 100]) #get actual values + SECOND_TRACK_CORNERS_X: chex.Array = jnp.array([20, 50, 80, 100]) #get actual values + SECOND_TRACK_CORNERS_Y: chex.Array = jnp.array([20, 50, 80, 100]) #get actual values +# immutable state container class EntityPosition(NamedTuple): x: jnp.ndarray y: jnp.ndarray width: jnp.ndarray height: jnp.ndarray - -class EnemyCar(NamedTuple): +class Car(NamedTuple): position: EntityPosition speed: chex.Array type: chex.Array + current_road: chex.Array + road_index_A: chex.Array + road_index_B: chex.Array + direction_x: chex.Array + +class UpNDownState(NamedTuple): + score: chex.Array + difficulty: chex.Array + road_index: chex.Array + jump_cooldown: chex.Array + is_jumping: chex.Array + is_on_road: chex.Array + player_car: Car + + class UpNDownObservation(NamedTuple): @@ -65,270 +81,186 @@ def __init__(self, consts: UpNDownConstants = None, reward_funcs: list[callable] self.action_set = [ Action.NOOP, Action.FIRE, - Action.RIGHT, - Action.LEFT, - Action.RIGHTFIRE, - Action.LEFTFIRE, + Action.UPFIRE, + Action.UP, + Action.DOWN, + Action.DOWNFIRE, ] self.obs_size = 3*4+1+1 - def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: - up = jnp.logical_or(action == Action.LEFT, action == Action.LEFTFIRE) - down = jnp.logical_or(action == Action.RIGHT, action == Action.RIGHTFIRE) - - acceleration = self.consts.PLAYER_ACCELERATION[state.acceleration_counter] - - touches_wall = jnp.logical_or( - state.player_y < self.consts.WALL_TOP_Y, - state.player_y + self.consts.PLAYER_SIZE[1] > self.consts.WALL_BOTTOM_Y, + @partial(jax.jit, static_argnums=(0,)) + def _car_past_corner(self, car: Car, state: UpNDownState) -> chex.Array: + direction_change_A = jnp.logical_or(jnp.logical_and(car.speed > 0, car.position.y + car.speed > self.consts.FIRST_TRACK_CORNERS_Y[car.road_index+1]), jnp.logical_and(car.speed < 0, car.position.y + car.speed < self.consts.FIRST_TRACK_CORNERS_Y[car.road_index])) + direction_change_B = jnp.logical_or(jnp.logical_and(car.speed > 0, car.position.y + car.speed > self.consts.SECOND_TRACK_CORNERS_Y[car.road_index+1]), jnp.logical_and(car.speed < 0, car.position.y + car.speed < self.consts.SECOND_TRACK_CORNERS_Y[car.road_index])), + + road_index_A = jax.lax.cond(jnp.logical_and(direction_change_A, car.speed > 0), + lambda s: s + 1, + lambda s: s, + operand=car.road_index_A, ) - - player_speed = state.player_speed - - player_speed = jax.lax.cond( - jnp.logical_or(jnp.logical_not(jnp.logical_or(up, down)), touches_wall), - lambda s: jnp.round(s / 2).astype(jnp.int32), + road_index_A = jax.lax.cond(jnp.logical_and(direction_change_A, car.speed < 0), + lambda s: s - 1, lambda s: s, - operand=player_speed, + operand=car.road_index_A, ) - direction_change_up = jnp.logical_and(up, state.player_speed > 0) - player_speed = jax.lax.cond( - direction_change_up, - lambda s: 0, + road_index_B = jax.lax.cond(jnp.logical_and(direction_change_B, car.speed > 0), + lambda s: s + 1, lambda s: s, - operand=player_speed, + operand=car.road_index_B, ) - direction_change_down = jnp.logical_and(down, state.player_speed < 0) - - player_speed = jax.lax.cond( - direction_change_down, - lambda s: 0, + road_index_B = jax.lax.cond(jnp.logical_and(direction_change_B, car.speed < 0), + lambda s: s - 1, lambda s: s, - operand=player_speed, + operand=car.road_index_B, ) + current_road_length_A = self.consts.FIRST_ROAD_LENGTH + current_road_length_B = self.consts.SECOND_ROAD_LENGTH - direction_change = jnp.logical_or(direction_change_up, direction_change_down) - acceleration_counter = jax.lax.cond( - direction_change, - lambda _: 0, + road_index_A = jax.lax.cond(road_index_A < 0, + lambda s: current_road_length_A - 1, lambda s: s, - operand=state.acceleration_counter, + operand=road_index_A, ) - player_speed = jax.lax.cond( - up, - lambda s: jnp.maximum(s - acceleration, -self.consts.MAX_SPEED), + road_index_A = jax.lax.cond(road_index_A >= current_road_length_A, + lambda s: 0, lambda s: s, - operand=player_speed, + operand=road_index_A, ) - player_speed = jax.lax.cond( - down, - lambda s: jnp.minimum(s + acceleration, self.consts.MAX_SPEED), + road_index_B = jax.lax.cond(road_index_B < 0, + lambda s: current_road_length_B - 1, lambda s: s, - operand=player_speed, + operand=road_index_B, ) - new_acceleration_counter = jax.lax.cond( - jnp.logical_or(up, down), - lambda s: jnp.minimum(s + 1, 15), + road_index_B = jax.lax.cond(road_index_B >= current_road_length_B, lambda s: 0, - operand=acceleration_counter, + lambda s: s, + operand=road_index_B, ) - proposed_player_y = jnp.clip( - state.player_y + player_speed, - self.consts.WALL_TOP_Y + self.consts.WALL_TOP_HEIGHT - 10, - self.consts.WALL_BOTTOM_Y - 4, - ) + return road_index_A, road_index_B + + @partial(jax.jit, static_argnums=(0,)) + def _landing_in_water(self, state: UpNDownState, new_position_x: chex.Array, new_position_y: chex.Array) -> chex.Array: + road_A_x = ((new_position_y - self.consts.FIRST_TRACK_CORNERS_Y[state.player_car.road_index_A]) / (self.consts.FIRST_TRACK_CORNERS_Y[state.player_car.road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_Y[state.player_car.road_index_A])) * (self.consts.FIRST_TRACK_CORNERS_X[state.player_car.road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_X[state.player_car.road_index_A]) + self.consts.FIRST_TRACK_CORNERS_X[state.player_car.road_index_A] + road_B_x = ((new_position_y - self.consts.SECOND_TRACK_CORNERS_Y[state.player_car.road_index_B]) / (self.consts.SECOND_TRACK_CORNERS_Y[state.player_car.road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_Y[state.player_car.road_index_B])) * (self.consts.SECOND_TRACK_CORNERS_X[state.player_car.road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[state.player_car.road_index_B]) + self.consts.SECOND_TRACK_CORNERS_X[state.player_car.road_index_B] + distance_to_road_A = jnp.abs(new_position_x - road_A_x) + distance_to_road_B = jnp.abs(new_position_x - road_B_x) + landing_in_Water = jnp.logical_and(distance_to_road_A > self.consts.LANDING_ZONE, distance_to_road_B > self.consts.LANDING_ZONE) + between_roads = jnp.logical_and(new_position_x > jnp.minimum(road_A_x, road_B_x), new_position_x < jnp.maximum(road_A_x, road_B_x)) + return landing_in_Water, between_roads - # Match original timing/buffering behavior - new_player_y, new_player_speed, new_acc_counter = jax.lax.cond( - state.step_counter % 2 == 0, - lambda _: (proposed_player_y, player_speed, new_acceleration_counter), - lambda _: (state.player_y, state.player_speed, state.acceleration_counter), - operand=None, - ) + def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: + up = jnp.logical_or(action == Action.UP, action == Action.UPFIRE) + down = jnp.logical_or(action == Action.DOWN, action == Action.DOWNFIRE) + jump = jnp.logical_or(action == Action.FIRE, action == Action.UPFIRE, action == Action.DOWNFIRE) - buffer = jax.lax.cond( - jax.lax.eq(state.buffer, state.player_y), - lambda _: new_player_y, - lambda _: state.buffer, - operand=None, - ) - final_player_y = state.buffer - return UpNDownState( - player_y=final_player_y, - player_speed=new_player_speed, - ball_x=state.ball_x, - ball_y=state.ball_y, - enemy_y=state.enemy_y, - enemy_speed=state.enemy_speed, - ball_vel_x=state.ball_vel_x, - ball_vel_y=state.ball_vel_y, - player_score=state.player_score, - enemy_score=state.enemy_score, - step_counter=state.step_counter, - acceleration_counter=new_acc_counter, - buffer=buffer, - ) - def _ball_step(self, state: UpNDownState, action) -> UpNDownState: - ball_x = state.ball_x + state.ball_vel_x - ball_y = state.ball_y + state.ball_vel_y + player_speed = state.player_car.speed - wall_bounce = jnp.logical_or( - ball_y <= self.consts.WALL_TOP_Y + self.consts.WALL_TOP_HEIGHT - self.consts.BALL_SIZE[1], - ball_y >= self.consts.WALL_BOTTOM_Y, + player_speed = jax.lax.cond( + jnp.logical_and(state.player_car.speed < self.consts.MAX_SPEED, up), + lambda s: s + 1, + lambda s: s, + operand=player_speed, ) - ball_vel_y = jnp.where(wall_bounce, -state.ball_vel_y, state.ball_vel_y) - player_paddle_hit = jnp.logical_and( - jnp.logical_and(self.consts.PLAYER_X <= ball_x, ball_x <= self.consts.PLAYER_X + self.consts.PLAYER_SIZE[0]), - state.ball_vel_x > 0, + player_speed = jax.lax.cond( + jnp.logical_and(state.player_car.speed > -self.consts.MAX_SPEED, down), + lambda s: s - 1, + lambda s: s, + operand=player_speed, ) - player_paddle_hit = jnp.logical_and( - player_paddle_hit, - jnp.logical_and( - state.player_y - self.consts.BALL_SIZE[1] <= ball_y, - ball_y <= state.player_y + self.consts.PLAYER_SIZE[1] + self.consts.BALL_SIZE[1], - ), - ) - enemy_paddle_hit = jnp.logical_and( - jnp.logical_and(self.consts.ENEMY_X <= ball_x, ball_x <= self.consts.ENEMY_X + self.consts.ENEMY_SIZE[0] - 1), - state.ball_vel_x < 0, + is_jumping = jnp.logical_or(jnp.logical_and(state.is_jumping, state.jump_cooldown > 0), jnp.logical_and(is_on_road, jnp.logical_and(player_speed > 0, state.jump_cooldown == 0))) + jump_cooldown = jax.lax.cond( + state.jump_cooldown > 0, + lambda s: s - 1, + lambda s: jnp.cond(jnp.logical_and(is_jumping), + lambda _: state.JUMP_FRAMES, + lambda _: 0, + operand=None), + operand=state.jump_cooldown, ) - enemy_paddle_hit = jnp.logical_and( - enemy_paddle_hit, - jnp.logical_and( - state.enemy_y - self.consts.BALL_SIZE[1] <= ball_y, - ball_y <= state.enemy_y + self.consts.ENEMY_SIZE[1] + self.consts.BALL_SIZE[1], - ), - ) - paddle_hit = jnp.logical_or(player_paddle_hit, enemy_paddle_hit) - - section_height = self.consts.PLAYER_SIZE[1] / 5 - - hit_position = jnp.where( - paddle_hit, - jnp.where( - player_paddle_hit, - jnp.where( - ball_y < state.player_y + section_height, - -2.0, - jnp.where( - ball_y < state.player_y + 2 * section_height, - -1.0, - jnp.where( - ball_y < state.player_y + 3 * section_height, - 0.0, - jnp.where( - ball_y < state.player_y + 4 * section_height, - 1.0, - 2.0, - ), - ), - ), - ), - jnp.where( - ball_y < state.enemy_y + section_height, - -2.0, - jnp.where( - ball_y < state.enemy_y + 2 * section_height, - -1.0, - jnp.where( - ball_y < state.enemy_y + 3 * section_height, - 0.0, - jnp.where( - ball_y < state.enemy_y + 4 * section_height, - 1.0, - 2.0, - ), - ), - ), - ), - ), - 0.0, - ) - paddle_speed = jnp.where( - player_paddle_hit, - state.player_speed, - jnp.where( - enemy_paddle_hit, - state.enemy_speed, - 0.0, - ), - ) - ball_vel_y = jnp.where(paddle_hit, hit_position, ball_vel_y) + ##check if player is on the the road + is_on_road = ~state.is_jumping - boost_triggered = jnp.logical_and( - player_paddle_hit, - jnp.logical_or( - jnp.logical_or(action == Action.LEFTFIRE, action == Action.RIGHTFIRE), - action == Action.FIRE, - ), - ) - player_max_hit = jnp.logical_and(player_paddle_hit, state.player_speed == self.consts.MAX_SPEED) - ball_vel_x = jnp.where( - jnp.logical_or(boost_triggered, player_max_hit), - state.ball_vel_x - + jnp.sign(state.ball_vel_x), - state.ball_vel_x, - ) + road_index_A, road_index_B = self._car_past_corner(state.player_car, state) - ball_vel_x = jnp.where( - paddle_hit, - -ball_vel_x, - ball_vel_x, + direction_change = jax.lax.cond( + jnp.logical_and(is_on_road, jnp.logical_or(jnp.logical_and(jnp.equal(road_index_A, state.player_car.road_index_A)) , state.player_car.current_road == 0), (jnp.logical_and(jnp.equal(road_index_B, state.player_car.road_index_B)) , state.player_car.current_road == 1) ), + lambda s: False, + lambda s: True, + operand=None, ) - return UpNDownState( - player_y=state.player_y, - player_speed=state.player_speed, - ball_x=ball_x.astype(jnp.int32), - ball_y=ball_y.astype(jnp.int32), - enemy_y=state.enemy_y, - enemy_speed=state.enemy_speed, - ball_vel_x=ball_vel_x.astype(jnp.int32), - ball_vel_y=ball_vel_y.astype(jnp.int32), - player_score=state.player_score, - enemy_score=state.enemy_score, - step_counter=state.step_counter, - acceleration_counter=state.acceleration_counter, - buffer=state.buffer, + + car_direction_x = jax.lax.cond( + direction_change, + lambda s: jax.lax.cond(state.player_car.current_road == 0, + lambda s: self.consts.FIRST_TRACK_CORNERS_X[road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_X[road_index_A], + lambda s: self.consts.SECOND_TRACK_CORNERS_X[road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[road_index_B], + operand=None), + lambda s: s, + operand=state.player_car.direction_x, ) + + is_landing = jnp.logical_and(state.jump_cooldown == 1, jump_cooldown == 0) - def _enemy_step(self, state: UpNDownState) -> UpNDownState: - should_move = state.step_counter % 8 != 0 + ##calculate new position with speed (TODO: calculate better speed) + player_y = state.player_car.position.y + player_speed + player_x = state.player_car.position.x + player_speed * car_direction_x - direction = jnp.sign(state.ball_y - state.enemy_y) + landing_in_Water, between_roads, road_A_x, road_B_x = self._landing_in_water(state, player_x, player_y) + landing_in_Water = jnp.logical_and(is_landing, landing_in_Water) + - new_y = state.enemy_y + (direction * self.consts.ENEMY_STEP_SIZE).astype(jnp.int32) - enemy_y = jax.lax.cond( - should_move, lambda _: new_y, lambda _: state.enemy_y, operand=None + current_road = jax.lax.cond( + landing_in_Water, + lambda s: 2, + lambda s: jax.lax.cond( + is_on_road, + lambda s: state.player_car.current_road, + lambda s: jax.lax.cond( + jnp.abs(player_x - road_A_x) < jnp.abs(player_x - road_B_x), + lambda s: 0, + lambda s: 1, + operand=None, + ), + operand=None, + ), + operand=None, ) return UpNDownState( - player_y=state.player_y, - player_speed=state.player_speed, - ball_x=state.ball_x, - ball_y=state.ball_y, - enemy_y=enemy_y.astype(jnp.int32), - enemy_speed=state.enemy_speed, - ball_vel_x=state.ball_vel_x, - ball_vel_y=state.ball_vel_y, - player_score=state.player_score, - enemy_score=state.enemy_score, - step_counter=state.step_counter, - acceleration_counter=state.acceleration_counter, - buffer=state.buffer, + score=state.score, + difficulty=state.difficulty, + road_index=state.road_index, + jump_cooldown=jump_cooldown, + is_jumping=is_jumping, + is_on_road=is_on_road, + player_car=Car( + position=EntityPosition( + x=player_x, + y=player_y, + width=state.player_car.position.width, + height=state.player_car.position.height, + ), + speed=player_speed, + direction_x=car_direction_x, + current_road=current_road, + road_index_A=road_index_A, + road_index_B=road_index_B, + type=state.player_car.type, + ), ) def _score_and_reset(self, state: UpNDownState) -> UpNDownState: @@ -405,26 +337,6 @@ def _score_and_reset(self, state: UpNDownState) -> UpNDownState: buffer=state.buffer, ) - def _reset_ball_after_goal(self, state_and_goal: Tuple[UpNDownState, bool]) -> Tuple[chex.Array, chex.Array, chex.Array, chex.Array]: - state, scored_right = state_and_goal - - ball_vel_y = jnp.where( - state.ball_y > self.consts.BALL_START_Y, - 1, - -1, - ).astype(jnp.int32) - - ball_vel_x = jnp.where( - scored_right, 1, -1 - ).astype(jnp.int32) - - return ( - self.consts.BALL_START_X.astype(jnp.int32), - self.consts.BALL_START_Y.astype(jnp.int32), - ball_vel_x.astype(jnp.int32), - ball_vel_y.astype(jnp.int32), - ) - def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: state = UpNDownState( player_y=jnp.array(96).astype(jnp.int32), From 35a7dbca43eb607bd1c97cea1aa19a739638d0f0 Mon Sep 17 00:00:00 2001 From: shaik05 Date: Thu, 13 Nov 2025 12:07:26 +0100 Subject: [PATCH 40/76] added basic interface template --- src/jaxatari/games/upndown_interface.py | 53 +++++++++++++++++++++++++ 1 file changed, 53 insertions(+) create mode 100644 src/jaxatari/games/upndown_interface.py diff --git a/src/jaxatari/games/upndown_interface.py b/src/jaxatari/games/upndown_interface.py new file mode 100644 index 000000000..68f8c76fe --- /dev/null +++ b/src/jaxatari/games/upndown_interface.py @@ -0,0 +1,53 @@ +import jax +import jax.numpy as jnp +import matplotlib.pyplot as plt + +from jaxatari.environment import JAXAtariAction as Action +from upndown import JaxUpNDown, UpNDownConstants # <-- your game file + +def visualize_frame(frame: jnp.ndarray): + """Render an RGB frame using matplotlib.""" + plt.imshow(frame.astype(jnp.uint8)) + plt.axis("off") + plt.show(block=False) + plt.pause(0.05) + plt.clf() + + +def main(): + # Initialize environment + env = JaxUpNDown(UpNDownConstants()) + + # Reset environment + obs, state = env.reset() + print("Initial observation:", obs) + + # Display initial render + frame = env.render(state) + visualize_frame(frame) + + # Create a random key for sampling actions + key = jax.random.PRNGKey(0) + + # Run for 50 steps + for step in range(50): + key, subkey = jax.random.split(key) + # Choose a random action from action space + action = jax.random.choice(subkey, jnp.arange(len(env.action_set))) + + obs, state, reward, done, info = env.step(state, action) + + # Render and display + frame = env.render(state) + visualize_frame(frame) + + print(f"Step {step}: action={env.action_set[int(action)]}, reward={reward}, done={done}") + + if bool(done): + print("Game over — resetting environment.") + obs, state = env.reset() + + plt.close() + +if __name__ == "__main__": + main() From 9eb102d64bf724c90bb2d277dafc855841267967 Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Thu, 13 Nov 2025 20:36:52 +0100 Subject: [PATCH 41/76] add partial backrounds and car sprites --- .../sprites/up_n_down/backround/backround1.npy | Bin 0 -> 102944 bytes .../up_n_down/backround/backround10.npy | Bin 0 -> 25548 bytes .../up_n_down/backround/backround11.npy | Bin 0 -> 37180 bytes .../up_n_down/backround/backround12.npy | Bin 0 -> 43808 bytes .../up_n_down/backround/backround13.npy | Bin 0 -> 45096 bytes .../sprites/up_n_down/backround/backround2.npy | Bin 0 -> 37948 bytes .../sprites/up_n_down/backround/backround3.npy | Bin 0 -> 37328 bytes .../sprites/up_n_down/backround/backround4.npy | Bin 0 -> 46944 bytes .../sprites/up_n_down/backround/backround5.npy | Bin 0 -> 34848 bytes .../sprites/up_n_down/backround/backround6.npy | Bin 0 -> 34624 bytes .../sprites/up_n_down/backround/backround7.npy | Bin 0 -> 47560 bytes .../sprites/up_n_down/backround/backround8.npy | Bin 0 -> 41132 bytes .../sprites/up_n_down/backround/backround9.npy | Bin 0 -> 31748 bytes 13 files changed, 0 insertions(+), 0 deletions(-) create mode 100644 src/jaxatari/games/sprites/up_n_down/backround/backround1.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/backround/backround10.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/backround/backround11.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/backround/backround12.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/backround/backround13.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/backround/backround2.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/backround/backround3.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/backround/backround4.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/backround/backround5.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/backround/backround6.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/backround/backround7.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/backround/backround8.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/backround/backround9.npy diff --git a/src/jaxatari/games/sprites/up_n_down/backround/backround1.npy b/src/jaxatari/games/sprites/up_n_down/backround/backround1.npy new file mode 100644 index 0000000000000000000000000000000000000000..6c353b610ae66a21a8991791f5875675ae42bf48 GIT binary patch literal 102944 zcmeI4O^cjG6otpFn{2b3-6%{JK_lW1aHoji!c_<+VnB>c)Qum1!XG@Qk={H_IDO_; z)m!!U5yEih+`4t|t*7hl^o;rU>+in$=KBw>{Q zbNAcbi(j6BXh z;V>4)S~E5qQ|#{{&r0T)y+_u&H$D6O!`S-o$HJKY+1cLimC^l#u`t$}x!IXwzb&4X z%(>bOSpV#LwfTmz_1=$#u_xuF(M(TC-P`}mZaIDwW)jCG+ zceuw?&7*5__hY>3XKSDJy1H8b=>5W&{#jV{Ou61YW5zR|tC>$}H$UUUSQraq{`)b` zjIXQFn)SH4cfG8q^?v8vW8AZMO=>^xuQc6Q z*8XQnWX9E7vCwPZFc!vov31P5B4fwek6+z;g7L>Z5Bph3DYE9Mk zswX!;<5hm;GoPzfe)`q?)%C1L=`i-F{C6Jr%owlEh(6aV-FiLqDILbbn9HbQivAuL zV^3;k%4fc#K7OvI)+_&*=b>L!e)`;pzS3Op>H}kHjjDk@^U+tD>s_t#D_`xW)+=A- zSH8+mpZVx3&GoJhW2e8@p=PMgMENSe@>PEN%tv2ou6K193u9sI^yiiC>qyP8x9`^L z#>(zee++)l!drzuf*56yRp4OW6yZNi@)qY{@^jXQ7SG78Ws?Yk|eAV@8zc3cY z!dU(N7++_cX={D_JS|`U|9ykQJ$7lcXw4M!UD`ZVzhSKXvr^6c()=(M#(Fa!#@1tu z&t|2`kcy|z^{zgLG481}Yf$m@xxTgLeyw%udYR9yxAl7FQ#y?09#hOJhLj)1!dMuq z|34jShP{1$K3COlz1*+5p81qM++!-HT2pnss!#du{;lIxeXj4;$Nbgx%%^l1YyYg| z8LBgIedbr1`Bgl9u6K193u9!8+GUFIdt{6=ay4f}?dD@V^Hp^tW-1OnYdczcYW?ZdcE2&jD@i(qnE%KXXI+m zNUi7R+A3c8%(pjvxW`uSNu8nc*?;SJHy`U+U5ByLe~Z9nLd7e8b%s>^VJwU(Mymdb zjH!D;`J5T!=~JtC`mW}Ft@Y@7&S7kHO=_lYJ&b35<-76B=W6cPT92+Lj6J^R%+;By zc>1nZ`&a$Z^@Oo7rkJZTy#ixu2FmA*7*GFwnxCT`HS-<6fBZ_{)yF)ozK-7iz?iCi zbWNx-%CIAb2YV& z{V<-stC?@KJ}}mMPj@DaXMM_N{OnrwL*La~_aDaE|18NpQ!>N#^j*z-qjeYyW4)PM zfib>rd(UL+{Os4&te@-Y_trfaQ$5jlb?=#MouBhyJ@j48e7$uT+x}U}GbS_2XS~wP z*IS3NFc!xA-yiEev(fo^FL8hRu4ca0y2m}n8rie*8Lu?+?M<_P_DA2<%-33nvH6~r zJYzDWe8wxye64jD3uCRBV!k~v#@8u(+?!VY(sy<1d5vIy znXjtD*!<5*o-vtGKI4^UzN!vmVT{aGnd16AGS+)$oZ;45^-JH?)&90%j5V`oDBpb@t`^V!SwDSOGoR8k z++(bfJu9E_N;BWyH2Y_N^j*z-N{6x2ds5A8cAw0!Uiz+PKBdE07z<;kzdy#;jXNX8 zGvD5{>X*K&)x2iB$5=CaRzBmEX1={?_Rs$4yPEl^!&v>_PUlQ#(`1JA(RVfTQHQZG z<}$^2GDW=tW2{;ANZ-}mZ*Q9OV1M*o&3s`@|DA#}TV0bG_Cw#*%ooPO7@2UHVm$MO zu`os^$c(F*Z$-w~181Oo#w*Qyd(-Tn{n2-I7(2bkI77~SZ(8+B-_>C(jFB0>hUmLG zjD;~W6UO|XVRMGA=FBfmyYpiF9_}%9=Iq(kD*xVncV3JSV|iAp*9e($br=g{VNAdO za$i?&ygGx`eRm#i{2uNxw`Vtgb%s>^?mXQ1FqUVf%ZwWz#=@A(j2piKV{T7wyn2?b z?z{7F2Dn_b)7z<-nhSd5nrk{n?8LAog_Er5=U)8Vj&+vXs z)mZhbGg9lf_N((z>s5bheHeRkuV$oXytS|DullM#l|PJyF~vw2%lj?Gi1Mopsr4i7 zG1X($S3RoyYW>!}s=w;1{#1UoK8)#mP|dKnugX^&xazC=RDQKy z`C%-Kg|YR2rmfDf>Z^N1Dvf%*iKx`0Vg)Cyk0uf{8zQ4zFys| z7T-QzudbHco1G%%Lo1RVm@o;mzT5V;@PbE-5l`qcrux6?DgXP-;MIR zzkj~zdMUN`$?ZDVHs;mv_Iq)!z8bIXd|tm^+&J5=;ZWn9%`V$D3URrP}TcXv-V?U&U(m$f~6iT7h>t%N`A8t1~8&Doy4#QR|F z|7Ta1r*;j_W$p6fkXW!5tflUqLwM@iGbYx$<}49QU6ZjKFX2&FjbJUFYj&@6aZ(}{ ztOaYqn)Y`9Tfa)&dtEi`nv8Yrn!QAyfR_RBhqdz$;H`(%yh&efcq z^%(aw_fvcOj~!<}xtjg59^;8jRwuOv zYr&eaL3UJaby90htZ@d_i`w(~|KW0V-+uJTTCO+T-?vvopO5qB>b~a*)^?wveS5E- NFPAOuSI?iD{|}RAGR6P^ literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/backround/backround11.npy b/src/jaxatari/games/sprites/up_n_down/backround/backround11.npy new file mode 100644 index 0000000000000000000000000000000000000000..06b3d1675bf65113926d862950b295f1d1d9aa46 GIT binary patch literal 37180 zcmeI*Pl}vX6o&C~>K=rem4?nB!63Q-XOT>TIFiP+Bp^n*H4~FTAY=pXAS-Z(W!FHd z{&m6eY2^Q-6Qzg=AO=a0TV+dg>o)!FvTFV41awjc5D z)6HhH_wn%E!~Kwt4aos zpIg}+vond--v!T(dyM9h&C9uu_-tL-d-B+?Q_r(bMeEDCa_*F8$3Kf4%I0fyE;|F) zAJ6;nJjz|abhh6t+rRbY-0k~ZG&A$Hv5sWfd01c0m2;6JIumiUUpZIKMUF&{nMeDr z!8v=*ax^3JwQ+V{u6G_$Ud~;A=gfB>(ac=Gbhh6t+rRbZ&?r#a!yCG ztW(yTFVQ)B%Ca-C-h6Frzgc!3)|<=aTz%)OGZE{M>&@kIuAI}Eh;_*I=5je#&go3V zI^=qD`7StT&)FGSzcjYrEZeX3<`FN!IeW~qGqB!#ZEU|;b{^K7M_kU`{;njN+0t?y z$+G=gZys?uSI+5>tV2=mdU=h`*<&>$~N9XJ@%g(@h^R==4X4!dI zZ=U6HZhe2dr6XC^A?wYvT+WqqkyBBA7oD@`-ZOIlxv~9b*?z1yZ{<1M=j<`d&cJ%} zwXywX*?CxR-pb|N@;_U3B+EKvy?HB_bLE^4m2>qsGLaKH5^=Pjd2gQLJ{Qd-%AJAh z&DX~En`P%=y?Jjg=a%=`sv}v}A?wY1b2(Sem2I}i6WADhd$=|5X_qLp>X{mjSaa;}^!=WhRv^W5I2y%W|v{Z_Rt3 ow{^d|&rN@y^EGUBEUNcd#;x;->ZhL9KZlt4N&o-= literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/backround/backround12.npy b/src/jaxatari/games/sprites/up_n_down/backround/backround12.npy new file mode 100644 index 0000000000000000000000000000000000000000..ea76f49ffd82c144453aa08bd26c71bdfbb90aaa GIT binary patch literal 43808 zcmeI4O=}cE5Qf*Q7q4D~y{$?DM*IMe9y|mT5kx$NpeqI>N}^u;073Exyn4zn@E?xZ z!$9Z3Y)wt~Oi!(^Fzmc}tE;Q3pPdZ}=GT+wkDtE0H2XaJvUt5bes#2XFkgKAw6oZn zFW$UA`gpW|xcC0(^>X#~NBi%NmuFu;e!Kr+d8V)5zqK{rxqEMGe&@#4{M)?8zss}P ztVyrK-w&Ib*TKP+#&fY&uBAUUZPiQP(o^e?^`xw=?&*)Mz4K)J*!7m)&+3Q1rAybB zo3A!+*1YJmU(09w)V0+QeM?WRU-4G=`v>wsek`By)LdU%i#PE$CWFy3o3 zl1IG7;%#?#x5>G>FfnAz*u3S;YPELASwzSkeSYFi&$lt(XW~t~iMNZtM;UWA>-M_4 zceSZ;<$j*ls5#w}&$rYY;y(JkhM0*bXY~1rw>r-$4`XSpc=MfC^m(OcU*kUd{PcYD zyPuqml}+vdx|&Tc}1UBdiFK$qt8#zH{biEWo+Hg z=7#igc^k5}F1g8Dmo+&t(dU)Zdx^Et=O^Cu`>Yc0gRVKmn|SNWN{Kk}Cf-VL)HR2A z6K`EvDG?{$#9IlDy5*aprjecnzeO{&K;msF) ze&Vg)?-#tBl!_N`zVnJcuTtFf%@cio(VlOdci(f{ly#@4H=CSm>^Y3h8~e+d)oLyC z{ck$j{KT8@`+=$PEZ%(Q7j0hR&3C?>SM=u_d1Ze?>2-e_Zb*;)axeN`&3vUbc`Kct z`PjeZGv3n7w`rRFvp@P?&3vV`c$@a$Q{{P+AIoRFrJ1j^7H{G$Ww10~%(rR0rL$Yz zv)|~N{ZX?&`d-a^xw!VNYjpur5=CibT+w9L%ix2ukcK-i2aXwSiJa5)l+GjpXvmR>kCf-seES@Ys&GoPQXuPop z_HX%&w>0zB)~uKHSiU!&`MjF@iMR95H_nsus;#a1ET8pT@$|i#`-wO4mNG;>Qm@ZF z^NBa{mU2WsQm@ZF^NF{9ylwwDXlh(MbJk$_-gxHoYVJqg=$F>4m-#K<8_#@R&HZX? z*3Wt@-y6?-Ud{c)+xq7l=T%$tyjhRsd*hkUtGSi`ng``Q+tz^6UH2WICQ)-ps$v z&#%vJ<`=W&`>)Tx+|KUazx{mvb#|wpym&qxADtXe$Hz~nwz8;Q7qs8*N{`ab= z>+0&!V%x>_Zj0T;W4@@~rN*`x*4q}l^cY{dHfJlI4?QQvh4NWH8qfPn)3tEMzVRHC z&v>Pok6PWYe3g%W?YdOXR9|_HjHe&f%tsx~$9R>Ge(l&deE^NTJP{ zIZMG6a}I^Gwm*MG^Re$k)8;JAJR^rgK=irAc=nn;bu^y$Q!`)RbYJEf z&n>EXuIm11d^8{P*RIXk>bc2tjB=skmCyQp)8@>aReSk{is&;YxB*SIm^MUIWuSGZ2kMQ=yQkh?4|MIl`x2;t@ z`bxLGzTElD+4^UVHhWKQn47-RyuUVW&dga3X3d#7GiPfCqR$n^=ji3O8frcCmEKmX zM6c9&rus~uy;VNrQ){&z`bwu>Z`=9J*?MNS-P_H4*;nN=zH6=4OJC`(_oM2UGS5^W>2qJoXMERMt(U&iUGGQLH)s3JTh&+PlN-j< zSDN>irp;MDI3tJVtodBoCu;Uh`HU}JtNqYdy0ravHP2L^=(BIiXME{e?T5b7rR~?8 zHP6@F=R|H;pT5$(ziVyI`pFqNG-s>-XUM)N-F4qm*H`do(Y~m7wXfQJ zwg0H!bwA8m8P82}5ame4GkIvO*vJ6l;8DU%~?0|DSGaTBjrcywH1gh6opMo!z;KmNT#qvC@3jV8W2LFph6ZgB2ttf8)*1>30mIOB*n`1hpe-g z-I?7PpGd3C@wt0v=H72-6f3`9y?y!m-Mz`T$@kgCYV%<|dp4c@__~o4o& z_4}Lk#p>?!7t5>7>h|-^r{(9>t$z4)F`q8Z&gRo6kLJ^#)BpInKbcIn&2|0f`L?F( z{Cu+Azqq!2FU9X~em&mqUmRoKJB8mH&hGBDtylGye&3q)WPbX!YqNH^_mPJ#n(Qza z{o1uzGi$jRubDO0ko!t*&H0>~Hf!njHL`qaY;nD2&8(4$RNAbWwNz{|uURu|WFnO| zYi2DKTg+?L%o>?UrOldIOT`xR*325ebBJoziaMH)@yu7dHfzWCHPw5v6V)nzZN6DE zYq^*%$nLKl+PM7o<6mTrytcEuXJtJc7KDA1=U*>Pd}>F_@((~ z&8+2O)~uN|v$kWPsYUb*VtkXQxhubkYkODu=tp(id5kr#q#NVs)t@5$2V*F&QR8#>?oh{N;BWk+N_zi zTnv`P8f(J$nto|ot%rV84_%L(tVL_Xc-Bh!j32vJ>!BajL$BYg4R?mJ_hd);j8~fZ za%;0@)^agd5^H>KbJb|5anyS0M|JLXrDQF-cE+<7%4dA(TCJCUROhabSsTw8%HES5 zYe)(>p(xMnXhl%Cu`BQF`jEvKI3cGsvh*CI(7Zb z+IY`U_MYr0pYcjFU*FoSnYD(&zS&|vvsMmk>|Mi1b60*7*Y?hQU3C}M+Rph{9OuqB z`qqc+leb+Z-KC;8H=tnj4b=79gtQ}*pD_eZttd+|edr4-L&v>Po zFRCM~sktbhb7MSx>)3ms>PbJUyY|DZ#roN`-YK6aI~xn>Dl6mboZfj5llb Ru*TkAUOw3Nc`<)f{{<4M{W1Um literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/backround/backround3.npy b/src/jaxatari/games/sprites/up_n_down/backround/backround3.npy new file mode 100644 index 0000000000000000000000000000000000000000..ee7c16619682eb58b3f417f6eb3b8f43549aaa7d GIT binary patch literal 37328 zcmeI5&1w`u6ouQZ8((2|qcDO2+5O~2>J8s?aSBi?ls?=AM^9&`on7ed^Z32?P$K3%|BkP zzOK$L-(RiHm)++t&OWb~*PpLHoqbte>xa*tEM`Z?$BWt1M~m67*)6{AH%-&_ugkwD z?a*~`@t}=)4Xq__crOm=Pt3jQ<~n4}eKCE`y8C@`+s|{#eVdcM?`JZn{Ss^Y_o)Vk z#oGV(a1HxeQVpl4O&fBB)apIgLh|g!hUBWT;b-ua`zO}I_phX9K&;)Z{hpFdu_o5W z;%`c}Sd;f1%JDhX!rY`!f_@I@_BLh@#ai8GYEEZ?SQBeQJnCGln~#wF9pdH?k85)b zeUJ8@Yu4KIz1owzVsAck-hMNlJ+*ws4{5P>bH9fAmAgkayqfiKJbjNAYhta>6WQr| z-93Ge7Heg*Mi#u9tZ+R25bb4+YgxI`_iC;uz58r<`^k9r)AAYb(PC|P?>xQBA(m`- zHEZB_`XO4ZiM2jIVy$c*`g<9Azu9cskc;bcF8VQb=()yRtla5)HP^dq+S@P2vtO3a z_z*4D$~rU2hF6mnj;9|}i#4&<=Pf2%oU1g}`hDr{Sx;@6^?Eh?!SVFR)UjD(u2yaI zy_)OQroH`OJo{n!jE||sTHfzavf$QBwkEz94Uhhz{;nie?pOemwHTK5pEq$-%da<=PAI38u%V&H}TC9!x z4ka62O;$Lbe$HB~iM3RGmd2X(T+{dJ(7y9;dA*wJb3FZ$bndLNmsaoTdo|bdYVUfC z=X#dU_?)#^+x3cQTqxR0ncxyiTC2O&^ zd#+h5aIYEf)z*CUOV(mdtfk^ttcf+TcDvuP@*T!L(oe0eeCXGtOF!4R7i7os89%nR z@}Xao7Hh-r(Aeye9nMd`CN0*)S}LAvV~u^Y&JFsp^_$<1+y2G*`(Ew$WUaP(t-Ylm zTeF{>pMK3+td08)B|BoRX}8R=H>GQ9Z|TR@>>uZ+U$fpaw*6jntzB*0>(sjT{tngm E53}cxUjP6A literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/backround/backround4.npy b/src/jaxatari/games/sprites/up_n_down/backround/backround4.npy new file mode 100644 index 0000000000000000000000000000000000000000..65d9d322f490b9d5430979b2e88fc2e795ab12b9 GIT binary patch literal 46944 zcmeI4Pm3H?6vf-E8^1#DT47uSqJm2ol8qvQ3s*9jkpv`>8Fk~IFX7hjI;I#N3_rR& z@63Hw(_M1{P36?N_rCk?`*n5bN%GIv-+lGX_wVg~-~DlTd3E#S_2KjV;m_Zm9WM5V zpWa^odVTrkhqu=+ucrH7UjBS@_3r-7tIJ=m-sulN`}AV}?33pg`{y5B?El(7;Me=R z-R^dJz4`Zxyb3dT{rZDj=QVoXT-&W*%~$I^l(TAI)qJ%t)_3(|*H`P)=lvR3J^VNFcESyb@R5=Z2$2sC$sny&V?`l=w^|}7q`D(qbKF_#s zp2OPu+rK})eaf%K+16*UHdkD4>+@Cn3uh1eTqQ?Jt8>`8AI`$r8jL>`&Nxf;{Hgj} zZ?ty%FEdvua<}e6U!0D_Cp=c!r8<`ID6Qi(Kt6(^ZnR#`~K0b z$M{y~sa)IhY-{`Fxw=}-!>z}7Rex)KI6MBism__)xLVcUnjg->*&0l5g){eg!g$V_ zKD8Uqd}`KHn)ynzo{Fc>d{_5!#&wluA1a&ps-RjI;tM#h+bA8R-Z=B8f?S8edYJTgZUEO+&Z*|&L`7N$pRXlxHw{_;O)q2(Z zxxVJ^hqJq%H*oHA^J@21JbhQUb>^_s@m1o>}}~4CldE z`mR>>=(}3=!+I*7el?%YVAjTJ^(vDxQ8dpY^E2SvZ@R2xrItoX)vW^ZlfJ#?R^YJ({|5J;u{_HS3+aZqKt* z>(iY(#UKTFqVKfRnS-IrUR@!Xg48Nakv{m^%{+yBz_*-toI=5v#C=D8}L@k+Cv(yAVP zRZschES!<4V_I>f{Bz}u`{HcqkJ5ABtGV~tFZ1cUn)Rqh*JFL^#u@KZkM7H@$9UBr z{ZrHKei*N`+ONu|uj(m3oIUK{yQ{e>U!5a;SG()0`l_Dt!&x{}j7`4cNcrc&8TY~2 z&_6e=_Cw#*Y96eo;$6R&GuCzcV!W%X>ofnTHh_Yjw!L+|@_)c0_z4)@D=`mScZrFDCrjb6vyH{;ol@)>{b zTJ4vzLd}St!p(8`mSDje&OsiKR0>qu?s%CI&}y#d_h)nWe{PhG2d(RcOe`G&L8__>pNuH=UK^j*z*RUOX4*)ayI zTroeKohN5JN3x=P#w*QwuI}Yb^`U(B#d!MB9{*~d^j%#&k8tLlXLa67<&zt(L*Lb` z=jw14&U%=0xng`c+XiPm*B*A97}j_7TAXoDZl8(w-VdC2RafU9&b;%h&U>x=Dp$;} z_8rc`8JRk!!`U`Dsh%tIU%c4eR$i><>fYyBrN=4P?JMzP-#?Fie&KBFd7kR~KJU5{Kb%H_*h zkzdJf8WjS48ml&HMAWA0Bl-x}UR))#lTB_Nt%#`hGH-_p{H}>znoR z>f`nLVzrxpz5Kda-Og{mEWfR8?Wc>E^ZsP9nD;N9&HLYd!k@=o*KOsw`ggkZJZER! z*7JDwX6&h}_Pf57uAy+Y8@+pZI2+=u-t)B{-+KMYn(r=ei9fOy|6OgVZ&S#nHTFFB zT5D&S0v|cg&{}`U+H}65$SGO7`=(-TXy@ggtdX@^@6_exldb3R=56_gdh*=GdUIvG z@7}B3KeFcQ-Znhph2l3o)7dT_tN|jjT`r zBdMc@o0Qm`tl|EtHIO*eqkaFqkUFJ~L)Vk6-S6Mf_4n+T%9zA?_D+cv&$Pt(`x&p0*!ZYxf&!WdZ%Q-lM-}YkJx` zStDy2W3sNXq^*ycHK~D&+F4s)mvf|!*7I5SvFf9LP1Z0M>)FvC^~Pr7d91eeyruOq zAF|fYH`LmD^vB+6HpXkVrl+lwHL|8LChHnY+WM4OlbXn=oxS$DRWBhw>8^(9`w(Bp3?ect==~jduz4HTE5;{Ymff9dy!iY^R?Qg^U1A0<$Dc# zm#cQ~{yg8NJeZHwPT8}in0?k-%Ii=r*5081Ar-M4 c&l+39kRR%y9g2&cM?J5N`&#w9@!Ed>0ih7rLI3~& literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/backround/backround6.npy b/src/jaxatari/games/sprites/up_n_down/backround/backround6.npy new file mode 100644 index 0000000000000000000000000000000000000000..0642e3cff1f1ba56d35e673c11980f4dfe92951d GIT binary patch literal 34624 zcmeI2zi!k(5XMbQ!z;Kmplc!}6qFPw4Ty$<3Z2LyB1H+h0~+8dXsPm|CN8Y3zHs)A zcV_?ji)7u-d3HSW&Cg>e%AdFI-@N@IePPy6m=+x_}}(rw!A^UeO-etZ4# zW`D68zJI;FI_&PgKYZDK-QC$w&t7cWle5!Jd-`nC{%#liJZ_q%ACK$5ullLu_V#%{ zb>QA9-{IVx|J-?uFZIX$b2i48+S2%;{+M5Ayt(~T=SRkdPs|JLsW`(qZGVgxYR~0A zcRuv`rSG59XX?C!vo-59Xjppvq4`g7Ge4B z=FF{g$=X-{TK&$?n?B@lb*Y)N5RHbtu1;%QYwAbNj^DFuWXr9YoK63oZ`4Xtr!nT5 zJ~j2LX`X57vL>F!S}pY>XZO#|QSY^Wk~2OhW3W%o$Qe2FW57Zu%lpqF-`rxx)UV|C zvzR*N#)w8fG4&&7>E^KpDJ-w?T=&M*ucUbvQNPlzLOV0dyE!8QP{-(-9su%ie=AH9A!@8xmrcNcs zlIn&2nl+O%zgeoOd8}B-jnvl6lbn&WWBie``~T%}HPY0j#8|EtOV*Y(=QPijXyfj& z)TxBnt`eQafec zF*%d!g8nIW%Hi6mhty74134?_{Y=Ug`ja!hkNGfw+(>PoYgqYko^hU8vE{}kXXMP~ z!D>y;$e9&eZd`Ik&Ria>*5s@l&X57=?*ZtK@hoji&af`2ZCQ5>v88&Ve@qR?8RuCH z-pCm_BWL9@5Tkc^4SX(2ZH%XBYtA!iUFfe_*BWD?o>JRVGjdkX%#?D4{^YD47A$*- z+(>Q9+&ItbWmw7?`fG9*!WrfzwKeNttcUlDv>)iNsVO;o*dI-OQ^rDWq_$@6_Db@iul~K1vVY`Ezn{6@BPlui2e0iZE&u=k literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/backround/backround7.npy b/src/jaxatari/games/sprites/up_n_down/backround/backround7.npy new file mode 100644 index 0000000000000000000000000000000000000000..259a34f810ef7a6bb7bc68472c654c527964e2f2 GIT binary patch literal 47560 zcmeI2&uSDw7{u4BPqDXEiHAT$@R)-)#fyk12`1t}5=qpHdhr#!iibRbZ@6X%i}A;V z+M4e9x0JA$n)>z?2AxWu2&Up{;B`p)9Z;_K@8dA8T?eqR>Z@GGRcK+%7 z==9Cm`SD3}{prz%i<7JC7w?Zgo?PjBk00+X_a7bXEf4PREx#{2{JFbWEY{8E;j2IS z^tcC?mk-uAKiuz|UYDPlH#NO2XIwkiLEqKP+cnMo;QHvhnt4;x;cS?5oZOHD`mSc) z)O0utXG1U<&cazZ8-{@)_F26L*@LQAKJ&)bUC%SFlk1@GYUa&dtMzkzs$Tib8(W97 zVa`mK3stZDvAN>>;Vhhyi9U2Vn+9h*pLnn7yPA1J>sp*~-CPHKS2J&Jx^%tnc-D`- zU*Rn0*$@ndvv3y9{{P+c&3C~NdsB0co4q$bGn+oOThH;-%ZXWY%>u?s%YBSZA7tX?28yB^&g|l#0o2jyYFpcsuBm?ME6ug3dgU{(wjPQzu9NGa?`q~vO>;f$ zkG|5S ztXGTtFj?@h%8eOJ3&u|AxIvu&6oH?9t6 z9%qNYf2_NFcrTejpStgQ&PUDpUA^H*=u>k(`qXYc$5S)U)!{6hk*PNAa>e?&aK^sd_QJWwyuLMi<9zg~ z-TK<&RUZ4Jo+@W*UG!beedt@WKh8&=+O4lWUgfbr>TtH3?@h%8eOHqsrNdb`a~W&u z6<74bnaA1bufuhhkGhx4WWDlPPhC1*Zx+S;AJw4VJlPx-7@nt9aW?E3%ga$jm|_db=@ zvw!9(pY=*Jk2;)%vxW)9k@CqEbvO%W;q2dcK|W7>Ug*1;d9}6bhx3)zt32ha`ROx{ zI_Fu=vxW(B<7#q69nQj8IJ^D5$>)vF4}DiNuT-n~nO|C8J5TjPpYzZ4JX7nU?`rN- zsaEqdzqGz~p6Z7_=MQJs?@hH&+&5PYd_DpkGyZLW*#;3=u@kD<-2*TcXc=mXALthm*MRCUF+_Hs^@)hwVPjB z&;CmD)Oh933tr1 Syy8gt;VhhO!?5B=`F{cQg9AcR-q*Nn{ogFNp)N1F=Df6|#s?7KlVPumQHY05|XnxS@%KMk-p< z)z#IH>Bk>wZyM|*RWB>Onn^w7>Vd?wXbefw7V_;0f^BhaIxn3Bv)i0gb zJcpKIq&m&{fid(k$#-Fn9qAkx17lz;&hHP>@0`s~@*GyIDRB+oJL>_jGg$9aK!nv z$QW0zWR26Lhc$^toR5!2we~Tte#x4TMlJeEG~#@GGy!9Aeyj4)yaZp4AIX|e4PXq6 zNtiNyT4c#LXPAEE&m>*HN%GUGc2WAnc` zQXNAyY9W?tp|L|WXtj@pXwpKgh9epiQ-e|&qxEX+m>SiXOQRW$iCHrkYx6utdC}Nl z4CgNw0~)Sqj4Qv-dMka*r{)%YDIXd;r6#2?rqPAQq|~T}+*AvVow7zS*2Z~E!xfDI zV>o}g7@)jpY*+q}^;-K_NX;$8YB-`XIW<^{Fd@t%f$fYMPs}2r(8!LE3=PjIHEBrneok?iZNQR#`dkFHGhp}G^Qjyz!>^iY97HD z7z1NvG7$T@(D$6ij(tBd^Tq088eM2i%$jP*!<4^f3)bG-gR=eD~HeM(fqszW3RhzeY0}QUz!(@y%_SJi!`R_) zehj%%bC5E>l(7(fPO&L#s3CXH`hhX@vDDvDFb2lJSRMva>x_5h;5;w&tWga9de#r` T$~=7%jG>Rk=v7>fo*DZKUKXOw literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/backround/backround9.npy b/src/jaxatari/games/sprites/up_n_down/backround/backround9.npy new file mode 100644 index 0000000000000000000000000000000000000000..ca78ccff5956da850b147227da97d14abebac977 GIT binary patch literal 31748 zcmeI*L2evX5XSLg%N>wrH-OB-65enCcA#t!VudVX!~zjy1Dhn1%~f4F=1^@rbX|Gaz0pWj?xZLe=`uC`x)akYKD{fJ+m zZZ@08>H6x&{~mi@Z{B?Nxa4K6daFn5emUCTTeqH8wqE_zBX++$!H zORu+{EL)HJ)w8UhTv>ZMpKE5vEU0H$Ke@72S(`GE&6evYpSgQwt+F;{!|Y^v@|nAj zlr?+E&SLcSZ{L3Txa76=Jl4$Gl5^U7tv#z#pEvsaDr<*#=&7?5onKiy{#lJat0)V- z_0i{7pKHheS?Qf6{moXSyFOz5dUMaL*{i*%XIZ~XW9MOi>RHyWH&@o`bFFvACjZJ> zWo^nr@B7GE^PSOqR%`p)pKLwqS=Mi9UNdXu5anaqQjxj@>u6PQO+@W!CJ`UevR!-)P+0 zZ+5@ym!4N;ZQOTglnwPPSJsaIe>F?bYHh#PY-IPle(8QM`E#xJ?3R8O`dPmmt&ep5 zBGxazC!qT4gQD*raE( z<@%AbX6@4qKK8EPxpCB2RImG8ubyT7B0i_AS*yKAy6Yp>@7%bxe)rp-dY1KzxUzOR z?@*t$*@<-5N337Om9@%Rl&LvA%9i>_S+i$*QO~k|r{<{sY<>EqE$>*FkYqj@C zcYVbAotl@fS3mnxk9eP*_t)qrSJn>y=H+Zo&F??Q5B)w{pC0a4kGQgS$WAmXb!UHS zwmvyqU$kF6;!Dn&vvn4c?)r%JTbp~=r@#HF%VteozmeXd)*khyo@LKlo1=PL>(O67 zb-A*3_#TF{$+ELwo1=PL>(O67b-A)uSu-}A{Jzqar zfAuJ9A6mcPv%mCiWzUk=_V@fOx7M58AJv!jbA8s|^_8{5_dK$*b(XSzz1j2S*7Nm~ z^;fT~Ro2YN9Ge~4ed?99L)M&4G&|SJ?(fZ>FSnkrU&NkQS*xs>i8+q4rLIq9t+Hk& z=GZJntWRaFvSudcILel~K2IrY_MT;DDL?g}rx4B4`9!?W=5$5gm!?OvQNOhQTl;Z7 n5l=bW-&@ZfZGSWy_0jt4U9a_+pCvZek%)KCkM literal 0 HcmV?d00001 From 75efded29f1200f4ab2b1e7dacad69c24188b184 Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Fri, 14 Nov 2025 23:02:04 +0100 Subject: [PATCH 42/76] create first running version doing nothing --- src/jaxatari/games/UpNDown.py | 575 ---- src/jaxatari/games/jax_upndown.py | 2490 ++--------------- .../up_n_down/backround/backround1.npy | Bin 102944 -> 0 bytes .../up_n_down/backround/backround10.npy | Bin 25548 -> 0 bytes .../up_n_down/backround/backround11.npy | Bin 37180 -> 0 bytes .../up_n_down/backround/backround12.npy | Bin 43808 -> 0 bytes .../up_n_down/backround/backround13.npy | Bin 45096 -> 0 bytes .../up_n_down/backround/backround2.npy | Bin 37948 -> 0 bytes .../up_n_down/backround/backround3.npy | Bin 37328 -> 0 bytes .../up_n_down/backround/backround4.npy | Bin 46944 -> 0 bytes .../up_n_down/backround/backround5.npy | Bin 34848 -> 0 bytes .../up_n_down/backround/backround6.npy | Bin 34624 -> 0 bytes .../up_n_down/backround/backround7.npy | Bin 47560 -> 0 bytes .../up_n_down/backround/backround8.npy | Bin 41132 -> 0 bytes .../up_n_down/backround/backround9.npy | Bin 31748 -> 0 bytes 15 files changed, 181 insertions(+), 2884 deletions(-) delete mode 100644 src/jaxatari/games/UpNDown.py delete mode 100644 src/jaxatari/games/sprites/up_n_down/backround/backround1.npy delete mode 100644 src/jaxatari/games/sprites/up_n_down/backround/backround10.npy delete mode 100644 src/jaxatari/games/sprites/up_n_down/backround/backround11.npy delete mode 100644 src/jaxatari/games/sprites/up_n_down/backround/backround12.npy delete mode 100644 src/jaxatari/games/sprites/up_n_down/backround/backround13.npy delete mode 100644 src/jaxatari/games/sprites/up_n_down/backround/backround2.npy delete mode 100644 src/jaxatari/games/sprites/up_n_down/backround/backround3.npy delete mode 100644 src/jaxatari/games/sprites/up_n_down/backround/backround4.npy delete mode 100644 src/jaxatari/games/sprites/up_n_down/backround/backround5.npy delete mode 100644 src/jaxatari/games/sprites/up_n_down/backround/backround6.npy delete mode 100644 src/jaxatari/games/sprites/up_n_down/backround/backround7.npy delete mode 100644 src/jaxatari/games/sprites/up_n_down/backround/backround8.npy delete mode 100644 src/jaxatari/games/sprites/up_n_down/backround/backround9.npy diff --git a/src/jaxatari/games/UpNDown.py b/src/jaxatari/games/UpNDown.py deleted file mode 100644 index af62fe461..000000000 --- a/src/jaxatari/games/UpNDown.py +++ /dev/null @@ -1,575 +0,0 @@ -from jax._src.pjit import JitWrapped -import os -from functools import partial -from typing import NamedTuple, Tuple -import jax.lax -import jax.numpy as jnp -import chex - -import jaxatari.spaces as spaces -from jaxatari.renderers import JAXGameRenderer -from jaxatari.rendering import jax_rendering_utils as render_utils -from jaxatari.environment import JaxEnvironment, JAXAtariAction as Action - -class UpNDownConstants(NamedTuple): - FRAME_SKIP: int = 4 - DIFFICULTIES: chex.Array = jnp.array([0, 1, 2, 3, 4, 5]) - ACTION_REPEAT_PROBS: float = 0.25 - MAX_SPEED: int = 4 - JUMP_FRAMES: int = 10 - LANDING_ZONE: int = 15 - FIRST_ROAD_LENGTH: int = 4 - SECOND_ROAD_LENGTH: int = 4 - FIRST_TRACK_CORNERS_X: chex.Array = jnp.array([20, 50, 80, 100]) #get actual values - FIRST_TRACK_CORNERS_Y: chex.Array = jnp.array([20, 50, 80, 100]) #get actual values - SECOND_TRACK_CORNERS_X: chex.Array = jnp.array([20, 50, 80, 100]) #get actual values - SECOND_TRACK_CORNERS_Y: chex.Array = jnp.array([20, 50, 80, 100]) #get actual values - - - -# immutable state container -class EntityPosition(NamedTuple): - x: jnp.ndarray - y: jnp.ndarray - width: jnp.ndarray - height: jnp.ndarray - -class Car(NamedTuple): - position: EntityPosition - speed: chex.Array - type: chex.Array - current_road: chex.Array - road_index_A: chex.Array - road_index_B: chex.Array - direction_x: chex.Array - -class UpNDownState(NamedTuple): - score: chex.Array - difficulty: chex.Array - road_index: chex.Array - jump_cooldown: chex.Array - is_jumping: chex.Array - is_on_road: chex.Array - player_car: Car - - - - -class UpNDownObservation(NamedTuple): - player: EntityPosition - enemies: jnp.ndarray - score: jnp.ndarray - -class Collectible(NamedTuple): - position: EntityPosition - type: chex.Array - value: chex.Array - - -class UpNDownInfo(NamedTuple): - time: jnp.ndarray - - -class JaxUpNDown(JaxEnvironment[UpNDownState, UpNDownObservation, UpNDownInfo, UpNDownConstants]): - def __init__(self, consts: UpNDownConstants = None, reward_funcs: list[callable]=None): - consts = consts or UpNDownConstants() - super().__init__(consts) - self.renderer = UpNDownRenderer(self.consts) - if reward_funcs is not None: - reward_funcs = tuple(reward_funcs) - self.reward_funcs = reward_funcs - self.action_set = [ - Action.NOOP, - Action.FIRE, - Action.UPFIRE, - Action.UP, - Action.DOWN, - Action.DOWNFIRE, - ] - self.obs_size = 3*4+1+1 - - @partial(jax.jit, static_argnums=(0,)) - def _car_past_corner(self, car: Car, state: UpNDownState) -> chex.Array: - direction_change_A = jnp.logical_or(jnp.logical_and(car.speed > 0, car.position.y + car.speed > self.consts.FIRST_TRACK_CORNERS_Y[car.road_index+1]), jnp.logical_and(car.speed < 0, car.position.y + car.speed < self.consts.FIRST_TRACK_CORNERS_Y[car.road_index])) - direction_change_B = jnp.logical_or(jnp.logical_and(car.speed > 0, car.position.y + car.speed > self.consts.SECOND_TRACK_CORNERS_Y[car.road_index+1]), jnp.logical_and(car.speed < 0, car.position.y + car.speed < self.consts.SECOND_TRACK_CORNERS_Y[car.road_index])), - - road_index_A = jax.lax.cond(jnp.logical_and(direction_change_A, car.speed > 0), - lambda s: s + 1, - lambda s: s, - operand=car.road_index_A, - ) - road_index_A = jax.lax.cond(jnp.logical_and(direction_change_A, car.speed < 0), - lambda s: s - 1, - lambda s: s, - operand=car.road_index_A, - ) - - road_index_B = jax.lax.cond(jnp.logical_and(direction_change_B, car.speed > 0), - lambda s: s + 1, - lambda s: s, - operand=car.road_index_B, - ) - road_index_B = jax.lax.cond(jnp.logical_and(direction_change_B, car.speed < 0), - lambda s: s - 1, - lambda s: s, - operand=car.road_index_B, - ) - current_road_length_A = self.consts.FIRST_ROAD_LENGTH - current_road_length_B = self.consts.SECOND_ROAD_LENGTH - - road_index_A = jax.lax.cond(road_index_A < 0, - lambda s: current_road_length_A - 1, - lambda s: s, - operand=road_index_A, - ) - - road_index_A = jax.lax.cond(road_index_A >= current_road_length_A, - lambda s: 0, - lambda s: s, - operand=road_index_A, - ) - - road_index_B = jax.lax.cond(road_index_B < 0, - lambda s: current_road_length_B - 1, - lambda s: s, - operand=road_index_B, - ) - - road_index_B = jax.lax.cond(road_index_B >= current_road_length_B, - lambda s: 0, - lambda s: s, - operand=road_index_B, - ) - - return road_index_A, road_index_B - - @partial(jax.jit, static_argnums=(0,)) - def _landing_in_water(self, state: UpNDownState, new_position_x: chex.Array, new_position_y: chex.Array) -> chex.Array: - road_A_x = ((new_position_y - self.consts.FIRST_TRACK_CORNERS_Y[state.player_car.road_index_A]) / (self.consts.FIRST_TRACK_CORNERS_Y[state.player_car.road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_Y[state.player_car.road_index_A])) * (self.consts.FIRST_TRACK_CORNERS_X[state.player_car.road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_X[state.player_car.road_index_A]) + self.consts.FIRST_TRACK_CORNERS_X[state.player_car.road_index_A] - road_B_x = ((new_position_y - self.consts.SECOND_TRACK_CORNERS_Y[state.player_car.road_index_B]) / (self.consts.SECOND_TRACK_CORNERS_Y[state.player_car.road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_Y[state.player_car.road_index_B])) * (self.consts.SECOND_TRACK_CORNERS_X[state.player_car.road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[state.player_car.road_index_B]) + self.consts.SECOND_TRACK_CORNERS_X[state.player_car.road_index_B] - distance_to_road_A = jnp.abs(new_position_x - road_A_x) - distance_to_road_B = jnp.abs(new_position_x - road_B_x) - landing_in_Water = jnp.logical_and(distance_to_road_A > self.consts.LANDING_ZONE, distance_to_road_B > self.consts.LANDING_ZONE) - between_roads = jnp.logical_and(new_position_x > jnp.minimum(road_A_x, road_B_x), new_position_x < jnp.maximum(road_A_x, road_B_x)) - return landing_in_Water, between_roads - - def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: - up = jnp.logical_or(action == Action.UP, action == Action.UPFIRE) - down = jnp.logical_or(action == Action.DOWN, action == Action.DOWNFIRE) - jump = jnp.logical_or(action == Action.FIRE, action == Action.UPFIRE, action == Action.DOWNFIRE) - - - - player_speed = state.player_car.speed - - player_speed = jax.lax.cond( - jnp.logical_and(state.player_car.speed < self.consts.MAX_SPEED, up), - lambda s: s + 1, - lambda s: s, - operand=player_speed, - ) - - player_speed = jax.lax.cond( - jnp.logical_and(state.player_car.speed > -self.consts.MAX_SPEED, down), - lambda s: s - 1, - lambda s: s, - operand=player_speed, - ) - - - is_jumping = jnp.logical_or(jnp.logical_and(state.is_jumping, state.jump_cooldown > 0), jnp.logical_and(is_on_road, jnp.logical_and(player_speed > 0, state.jump_cooldown == 0))) - jump_cooldown = jax.lax.cond( - state.jump_cooldown > 0, - lambda s: s - 1, - lambda s: jnp.cond(jnp.logical_and(is_jumping), - lambda _: state.JUMP_FRAMES, - lambda _: 0, - operand=None), - operand=state.jump_cooldown, - ) - - - - - ##check if player is on the the road - is_on_road = ~state.is_jumping - - road_index_A, road_index_B = self._car_past_corner(state.player_car, state) - - direction_change = jax.lax.cond( - jnp.logical_and(is_on_road, jnp.logical_or(jnp.logical_and(jnp.equal(road_index_A, state.player_car.road_index_A)) , state.player_car.current_road == 0), (jnp.logical_and(jnp.equal(road_index_B, state.player_car.road_index_B)) , state.player_car.current_road == 1) ), - lambda s: False, - lambda s: True, - operand=None, - ) - - - car_direction_x = jax.lax.cond( - direction_change, - lambda s: jax.lax.cond(state.player_car.current_road == 0, - lambda s: self.consts.FIRST_TRACK_CORNERS_X[road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_X[road_index_A], - lambda s: self.consts.SECOND_TRACK_CORNERS_X[road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[road_index_B], - operand=None), - lambda s: s, - operand=state.player_car.direction_x, - ) - - is_landing = jnp.logical_and(state.jump_cooldown == 1, jump_cooldown == 0) - - ##calculate new position with speed (TODO: calculate better speed) - player_y = state.player_car.position.y + player_speed - player_x = state.player_car.position.x + player_speed * car_direction_x - - landing_in_Water, between_roads, road_A_x, road_B_x = self._landing_in_water(state, player_x, player_y) - landing_in_Water = jnp.logical_and(is_landing, landing_in_Water) - - - current_road = jax.lax.cond( - landing_in_Water, - lambda s: 2, - lambda s: jax.lax.cond( - is_on_road, - lambda s: state.player_car.current_road, - lambda s: jax.lax.cond( - jnp.abs(player_x - road_A_x) < jnp.abs(player_x - road_B_x), - lambda s: 0, - lambda s: 1, - operand=None, - ), - operand=None, - ), - operand=None, - ) - return UpNDownState( - score=state.score, - difficulty=state.difficulty, - road_index=state.road_index, - jump_cooldown=jump_cooldown, - is_jumping=is_jumping, - is_on_road=is_on_road, - player_car=Car( - position=EntityPosition( - x=player_x, - y=player_y, - width=state.player_car.position.width, - height=state.player_car.position.height, - ), - speed=player_speed, - direction_x=car_direction_x, - current_road=current_road, - road_index_A=road_index_A, - road_index_B=road_index_B, - type=state.player_car.type, - ), - ) - - def _score_and_reset(self, state: UpNDownState) -> UpNDownState: - player_goal = state.ball_x < 4 - enemy_goal = state.ball_x > 156 - ball_reset = jnp.logical_or(enemy_goal, player_goal) - - player_score = jax.lax.cond( - player_goal, - lambda s: s + 1, - lambda s: s, - operand=state.player_score, - ) - enemy_score = jax.lax.cond( - enemy_goal, - lambda s: s + 1, - lambda s: s, - operand=state.enemy_score, - ) - - current_values = ( - state.ball_x.astype(jnp.int32), - state.ball_y.astype(jnp.int32), - state.ball_vel_x.astype(jnp.int32), - state.ball_vel_y.astype(jnp.int32), - ) - ball_x_final, ball_y_final, ball_vel_x_final, ball_vel_y_final = jax.lax.cond( - ball_reset, - lambda x: self._reset_ball_after_goal((state, enemy_goal)), - lambda x: x, - operand=current_values, - ) - - step_counter = jax.lax.cond( - ball_reset, - lambda s: jnp.array(0), - lambda s: s + 1, - operand=state.step_counter, - ) - - enemy_y_final = jax.lax.cond( - ball_reset, - lambda s: self.consts.BALL_START_Y.astype(jnp.int32), - lambda s: state.enemy_y.astype(jnp.int32), - operand=None, - ) - - ball_x_final = jax.lax.cond( - step_counter < 60, - lambda s: self.consts.BALL_START_X.astype(jnp.int32), - lambda s: s, - operand=ball_x_final, - ) - ball_y_final = jax.lax.cond( - step_counter < 60, - lambda s: self.consts.BALL_START_Y.astype(jnp.int32), - lambda s: s, - operand=ball_y_final, - ) - - return UpNDownState( - player_y=state.player_y, - player_speed=state.player_speed, - ball_x=ball_x_final, - ball_y=ball_y_final, - enemy_y=enemy_y_final, - enemy_speed=state.enemy_speed, - ball_vel_x=ball_vel_x_final, - ball_vel_y=ball_vel_y_final, - player_score=player_score, - enemy_score=enemy_score, - step_counter=step_counter, - acceleration_counter=state.acceleration_counter, - buffer=state.buffer, - ) - - def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: - state = UpNDownState( - player_y=jnp.array(96).astype(jnp.int32), - player_speed=jnp.array(0.0).astype(jnp.int32), - ball_x=jnp.array(78).astype(jnp.int32), - ball_y=jnp.array(115).astype(jnp.int32), - enemy_y=jnp.array(115).astype(jnp.int32), - enemy_speed=jnp.array(0.0).astype(jnp.int32), - ball_vel_x=self.consts.BALL_SPEED[0].astype(jnp.int32), - ball_vel_y=self.consts.BALL_SPEED[1].astype(jnp.int32), - player_score=jnp.array(0).astype(jnp.int32), - enemy_score=jnp.array(0).astype(jnp.int32), - step_counter=jnp.array(0).astype(jnp.int32), - acceleration_counter=jnp.array(0).astype(jnp.int32), - buffer=jnp.array(96).astype(jnp.int32), - ) - initial_obs = self._get_observation(state) - - return initial_obs, state - - @partial(jax.jit, static_argnums=(0,)) - def step(self, state: UpNDownState, action: chex.Array) -> Tuple[UpNDownObservation, UpNDownState, float, bool, UpNDownInfo]: - previous_state = state - state = self._player_step(state, action) - state = self._enemy_step(state) - state = self._ball_step(state, action) - state = self._score_and_reset(state) - - done = self._get_done(state) - env_reward = self._get_reward(previous_state, state) - info = self._get_info(state) - observation = self._get_observation(state) - - return observation, state, env_reward, done, info - - - def render(self, state: UpNDownState) -> jnp.ndarray: - return self.renderer.render(state) - - def _get_observation(self, state: UpNDownState): - player = EntityPosition( - x=jnp.array(self.consts.PLAYER_X), - y=state.player_y, - width=jnp.array(self.consts.PLAYER_SIZE[0]), - height=jnp.array(self.consts.PLAYER_SIZE[1]), - ) - - enemy = EntityPosition( - x=jnp.array(self.consts.ENEMY_X), - y=state.enemy_y, - width=jnp.array(self.consts.ENEMY_SIZE[0]), - height=jnp.array(self.consts.ENEMY_SIZE[1]), - ) - - ball = EntityPosition( - x=state.ball_x, - y=state.ball_y, - width=jnp.array(self.consts.BALL_SIZE[0]), - height=jnp.array(self.consts.BALL_SIZE[1]), - ) - return UpNDownObservation( - player=player, - enemy=enemy, - ball=ball, - score_player=state.player_score, - score_enemy=state.enemy_score, - ) - - @partial(jax.jit, static_argnums=(0,)) - def obs_to_flat_array(self, obs: UpNDownObservation) -> jnp.ndarray: - return jnp.concatenate([ - obs.player.x.flatten(), - obs.player.y.flatten(), - obs.player.height.flatten(), - obs.player.width.flatten(), - obs.enemy.x.flatten(), - obs.enemy.y.flatten(), - obs.enemy.height.flatten(), - obs.enemy.width.flatten(), - obs.ball.x.flatten(), - obs.ball.y.flatten(), - obs.ball.height.flatten(), - obs.ball.width.flatten(), - obs.score_player.flatten(), - obs.score_enemy.flatten() - ] - ) - - def action_space(self) -> spaces.Discrete: - return spaces.Discrete(6) - - def observation_space(self) -> spaces: - return spaces.Dict({ - "player": spaces.Dict({ - "x": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), - "y": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), - "width": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), - "height": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), - }), - "enemy": spaces.Dict({ - "x": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), - "y": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), - "width": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), - "height": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), - }), - "ball": spaces.Dict({ - "x": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), - "y": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), - "width": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), - "height": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), - }), - "score_player": spaces.Box(low=0, high=21, shape=(), dtype=jnp.int32), - "score_enemy": spaces.Box(low=0, high=21, shape=(), dtype=jnp.int32), - }) - - def image_space(self) -> spaces.Box: - return spaces.Box( - low=0, - high=255, - shape=(210, 160, 3), - dtype=jnp.uint8 - ) - - @partial(jax.jit, static_argnums=(0,)) - def _get_info(self, state: UpNDownState, ) -> UpNDownInfo: - return UpNDownInfo(time=state.step_counter) - - @partial(jax.jit, static_argnums=(0,)) - def _get_reward(self, previous_state: UpNDownState, state: UpNDownState): - return (state.player_score - state.enemy_score) - ( - previous_state.player_score - previous_state.enemy_score - ) - - @partial(jax.jit, static_argnums=(0,)) - def _get_done(self, state: UpNDownState) -> bool: - return jnp.logical_or( - jnp.greater_equal(state.player_score, 21), - jnp.greater_equal(state.enemy_score, 21), - ) - -class UpNDownRenderer(JAXGameRenderer): - def __init__(self, consts: UpNDownConstants = None): - super().__init__() - self.consts = consts or UpNDownConstants() - self.config = render_utils.RendererConfig( - game_dimensions=(210, 160), - channels=3, - #downscale=(84, 84) - ) - self.jr = render_utils.JaxRenderingUtils(self.config) - # 1. Create procedural assets for both walls - wall_sprite_top = self._create_wall_sprite(self.consts.WALL_TOP_HEIGHT) - wall_sprite_bottom = self._create_wall_sprite(self.consts.WALL_BOTTOM_HEIGHT) - - # 2. Update asset config to include both walls - asset_config = self._get_asset_config(wall_sprite_top, wall_sprite_bottom) - sprite_path = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/UpNDown" - - # 3. Make a single call to the setup function - ( - self.PALETTE, - self.SHAPE_MASKS, - self.BACKGROUND, - self.COLOR_TO_ID, - self.FLIP_OFFSETS - ) = self.jr.load_and_setup_assets(asset_config, sprite_path) - - def _create_wall_sprite(self, height: int) -> jnp.ndarray: - """Procedurally creates an RGBA sprite for a wall of given height.""" - wall_color_rgba = (*self.consts.SCORE_COLOR, 255) # e.g., (236, 236, 236, 255) - wall_shape = (height, self.consts.WIDTH, 4) - wall_sprite = jnp.tile(jnp.array(wall_color_rgba, dtype=jnp.uint8), (*wall_shape[:2], 1)) - return wall_sprite - - def _get_asset_config(self, wall_sprite_top: jnp.ndarray, wall_sprite_bottom: jnp.ndarray) -> list: - """Returns the declarative manifest of all assets for the game, including both wall sprites.""" - return [ - {'name': 'background', 'type': 'background', 'file': 'background.npy'}, - {'name': 'player', 'type': 'single', 'file': 'player.npy'}, - {'name': 'enemy', 'type': 'single', 'file': 'enemy.npy'}, - {'name': 'ball', 'type': 'single', 'file': 'ball.npy'}, - {'name': 'player_digits', 'type': 'digits', 'pattern': 'player_score_{}.npy'}, - {'name': 'enemy_digits', 'type': 'digits', 'pattern': 'enemy_score_{}.npy'}, - # Add the procedurally created sprites to the manifest - {'name': 'wall_top', 'type': 'procedural', 'data': wall_sprite_top}, - {'name': 'wall_bottom', 'type': 'procedural', 'data': wall_sprite_bottom}, - ] - - @partial(jax.jit, static_argnums=(0,)) - def render(self, state): - raster = self.jr.create_object_raster(self.BACKGROUND) - - player_mask = self.SHAPE_MASKS["player"] - raster = self.jr.render_at(raster, self.consts.PLAYER_X, state.player_y, player_mask) - - enemy_mask = self.SHAPE_MASKS["enemy"] - raster = self.jr.render_at(raster, self.consts.ENEMY_X, state.enemy_y, enemy_mask) - - ball_mask = self.SHAPE_MASKS["ball"] - raster = self.jr.render_at(raster, state.ball_x, state.ball_y, ball_mask) - - # --- Stamp Walls and Score (using the same color/ID) --- - score_color_tuple = self.consts.SCORE_COLOR # (236, 236, 236) - score_id = self.COLOR_TO_ID[score_color_tuple] - - # Draw walls (using separate sprites for top and bottom) - raster = self.jr.render_at(raster, 0, self.consts.WALL_TOP_Y, self.SHAPE_MASKS["wall_top"]) - raster = self.jr.render_at(raster, 0, self.consts.WALL_BOTTOM_Y, self.SHAPE_MASKS["wall_bottom"]) - - # Stamp Score using the label utility - player_digits = self.jr.int_to_digits(state.player_score, max_digits=2) - enemy_digits = self.jr.int_to_digits(state.enemy_score, max_digits=2) - - # Note: The logic for single/double digits is complex for a jitted function. - player_digit_masks = self.SHAPE_MASKS["player_digits"] # Assumes single color - enemy_digit_masks = self.SHAPE_MASKS["enemy_digits"] # Assumes single color - - is_player_single_digit = state.player_score < 10 - player_start_index = jax.lax.select(is_player_single_digit, 1, 0) - player_num_to_render = jax.lax.select(is_player_single_digit, 1, 2) - player_render_x = jax.lax.select(is_player_single_digit, - 120 + 16 // 2, - 120) - - raster = self.jr.render_label_selective(raster, player_render_x, 3, player_digits, player_digit_masks, player_start_index, player_num_to_render, spacing=16) - - is_enemy_single_digit = state.enemy_score < 10 - enemy_start_index = jax.lax.select(is_enemy_single_digit, 1, 0) - enemy_num_to_render = jax.lax.select(is_enemy_single_digit, 1, 2) - enemy_render_x = jax.lax.select(is_enemy_single_digit, - 10 + 16 // 2, - 10) - - raster = self.jr.render_label_selective(raster, enemy_render_x, 3, enemy_digits, enemy_digit_masks, enemy_start_index, enemy_num_to_render, spacing=16) - - return self.jr.render_from_palette(raster, self.PALETTE) \ No newline at end of file diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 4f7a3af1f..4d63a6455 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -1,8 +1,7 @@ +from jax._src.pjit import JitWrapped import os from functools import partial from typing import NamedTuple, Tuple - -import jax import jax.lax import jax.numpy as jnp import chex @@ -15,89 +14,17 @@ class UpNDownConstants(NamedTuple): FRAME_SKIP: int = 4 DIFFICULTIES: chex.Array = jnp.array([0, 1, 2, 3, 4, 5]) - MAX_SPEED: int = 7 - INITIAL_LIVES: int = 5 - JUMP_ARC_HEIGHT: float = 22.0 - RESPAWN_DELAY_FRAMES: int = 60 - RESPAWN_Y: int = 0 - RESPAWN_X: int = 30 - ALL_FLAGS_BONUS: int = 1000 - # Enemy spawning and movement - MAX_ENEMY_CARS: int = 8 - ENEMY_SPAWN_INTERVAL_BASE: int = 30 # Base spawn interval - ENEMY_SPAWN_INTERVAL_MAX: int = 60 # Max spawn interval when many enemies exist - ENEMY_MIN_VISIBLE_COUNT: int = 2 # Minimum enemies to keep on screen - ENEMY_VISIBLE_DISTANCE: int = 120 # Distance within which enemies are considered "visible" - ENEMY_DESPAWN_DISTANCE: int = 250 - ENEMY_SPEED_MIN: int = 3 - ENEMY_SPEED_MAX: int = 5 - ENEMY_DIRECTION_SWITCH_PROB: float = 0.0001 - ENEMY_SPAWN_OFFSET_MIN: float = 70.0 # Closer spawn distance - ENEMY_SPAWN_OFFSET_MAX: float = 130.0 # Max spawn offset - ENEMY_MIN_SPAWN_GAP: float = 25.0 # Reduced gap between spawns - ENEMY_MAX_AGE: int = 1900 - INITIAL_ENEMY_COUNT: int = 4 - INITIAL_ENEMY_BASE_OFFSET: float = 35.0 # Closer initial enemies - INITIAL_ENEMY_GAP: float = 25.0 # Tighter initial spacing - ENEMY_TYPE_CAMERO: int = 0 - ENEMY_TYPE_FLAG_CARRIER: int = 1 - ENEMY_TYPE_PICKUP: int = 2 - ENEMY_TYPE_TRUCK: int = 3 - JUMP_FRAMES: int = 28 - POST_JUMP_DELAY: int = 10 - LANDING_TOLERANCE: int = 20 # Pixels tolerance for landing on a road (increased by 5 for wider landing zone) - LATE_JUMP_COLLISION_FRAMES: int = 2 - LANDING_COLLISION_DISTANCE: float = 12.0 # Larger collision distance when landing (increased for easier enemy kills) - GROUND_COLLISION_DISTANCE: float = 3.0 # Tight collision distance for ground collisions - LATE_JUMP_ENEMY_SCORE: int = 400 - STEEP_ROAD_SPEED_REDUCTION_INTERVAL: int = 8 # Frames between each speed reduction on steep roads - PASSIVE_SCORE_INTERVAL: int = 60 # Steps between passive score awards - PASSIVE_SCORE_AMOUNT: int = 10 # Points awarded for passive scoring - COLLISION_THRESHOLD: float = 5.0 # Distance threshold for flag/collectible collision - ACCELERATION_INTERVAL: int = 6 # Frames between speed changes when holding up/down - EXTRA_LIFE_THRESHOLD: int = 10000 # Score threshold for extra life - TRACK_LENGTH: int = 1036 - FIRST_TRACK_CORNERS_X: chex.Array = jnp.array([30, 75, 128, 75, 21, 75, 131, 111, 150, 95, 150, 115, 150, 108, 150, 115, 115, 75, 18, 38, 67, 38, 38, 20, 64, 30]) - TRACK_CORNERS_Y: chex.Array = jnp.array([0, -40, -98, -155, -203, -268, -327, -347, -382, -467, -525, -565, -597, -625, -670, -705, -738, -788, -838, -862, -898, -925, -950, -972, -1000, -1035]) - SECOND_TRACK_CORNERS_X: chex.Array = jnp.array([115, 75, 20, 75, 133, 75, 22, 37, 63, 27, 66, 30, 63, 24, 60, 38, 38, 75, 131, 111, 150, 118, 118, 98, 150, 115]) + ACTION_REPEAT_PROBS: float = 0.25 + MAX_SPEED: int = 4 + JUMP_FRAMES: int = 10 + LANDING_ZONE: int = 15 + FIRST_ROAD_LENGTH: int = 4 + SECOND_ROAD_LENGTH: int = 4 + FIRST_TRACK_CORNERS_X: chex.Array = jnp.array([20, 50, 80, 100]) #get actual values + FIRST_TRACK_CORNERS_Y: chex.Array = jnp.array([20, 50, 80, 100]) #get actual values + SECOND_TRACK_CORNERS_X: chex.Array = jnp.array([20, 50, 80, 100]) #get actual values + SECOND_TRACK_CORNERS_Y: chex.Array = jnp.array([20, 50, 80, 100]) #get actual values PLAYER_SIZE: Tuple[int, int] = (4, 16) - INITIAL_ROAD_POS_Y: int = 25 - # Flag constants - 8 flags with different colors matching the top row - NUM_FLAGS: int = 8 - # Flag colors as RGBA values (matching the top row from left to right) - FLAG_COLORS: chex.Array = jnp.array([ - [184, 50, 50, 255], # Red - [181, 83, 40, 255], # Orange - [162, 98, 33, 255], # Dark orange - [134, 134, 29, 255], # Yellow/olive - [200, 72, 72, 255], # Pink (original) - [168, 48, 143, 255], # Magenta - [125, 48, 173, 255], # Purple - [78, 50, 181, 255], # Blue - ]) - # Top display positions for each flag (x coordinates where blackout squares appear) - FLAG_TOP_X_POSITIONS: chex.Array = jnp.array([13, 30, 47, 64, 82, 98, 118, 134]) - FLAG_TOP_Y: int = 20 - FLAG_BLACKOUT_SIZE: Tuple[int, int] = (14, 14) # Size of blackout square - FLAG_COLLECTION_SCORE: int = 75 # Points awarded for collecting a flag - # Life display constants - positions of life cars at the bottom - LIFE_BOTTOM_X_POSITIONS: chex.Array = jnp.array([13, 18, 25, 33, 33]) # X positions for 5 life cars - LIFE_BOTTOM_Y: int = 195 - # Collectible constants - unified dynamic spawning - MAX_COLLECTIBLES: int = 1 # Maximum collectibles that can exist at once (pool of mixed types) - COLLECTIBLE_SPAWN_INTERVAL: int = 200 # Steps between spawn attempts - COLLECTIBLE_DESPAWN_DISTANCE: int = 500 # Distance beyond which collectibles despawn - # Collectible types (indices for type field) - COLLECTIBLE_TYPE_CHERRY: int = 0 - COLLECTIBLE_TYPE_BALLOON: int = 1 - COLLECTIBLE_TYPE_LOLLYPOP: int = 2 - COLLECTIBLE_TYPE_ICE_CREAM: int = 3 - # Collectible type spawn probabilities (cumulative thresholds for random sampling) - COLLECTIBLE_SPAWN_PROBABILITIES: chex.Array = jnp.array([35, 65, 90, 100], dtype=jnp.int32) # Cherry: 35%, Balloon: 30%, Lollypop: 25%, IceCream: 10% - # Collectible type scores - COLLECTIBLE_SCORES: chex.Array = jnp.array([50, 65, 70, 75], dtype=jnp.int32) # [cherry, balloon, lollypop, ice_cream] - # Shared collectible colors - COLLECTIBLE_COLORS: chex.Array = FLAG_COLORS @@ -117,99 +44,30 @@ class Car(NamedTuple): road_index_B: chex.Array direction_x: chex.Array -class Flag(NamedTuple): - """Represents a collectible flag on the road.""" - y: chex.Array # Y position in world coordinates (like player_car.position.y) - road: chex.Array # Which road the flag is on (0 or 1) - road_segment: chex.Array # Which road segment index the flag is on - color_idx: chex.Array # Index into FLAG_COLORS array - collected: chex.Array # Whether this flag has been collected - -class Collectible(NamedTuple): - """Represents a dynamically spawning collectible item on the road. - - Can be any type: cherry (0), balloon (1), lollypop (2), or ice cream (3). - The type determines the sprite and point value. - """ - y: chex.Array # Y position in world coordinates - x: chex.Array # X position on the road - road: chex.Array # Which road the collectible is on (0 or 1) - color_idx: chex.Array # Index into COLLECTIBLE_COLORS array - type_id: chex.Array # Type of collectible (0=cherry, 1=balloon, 2=lollypop, 3=ice_cream) - active: chex.Array # Whether this collectible slot is active (spawned) - - -class EnemyCars(NamedTuple): - """Pool of enemy cars that share the same road-following logic as the player.""" - position: EntityPosition # vectorized position fields, size MAX_ENEMY_CARS - speed: chex.Array # signed speed per car - type: chex.Array # type id per car - current_road: chex.Array - road_index_A: chex.Array - road_index_B: chex.Array - direction_x: chex.Array - active: chex.Array - age: chex.Array - class UpNDownState(NamedTuple): score: chex.Array difficulty: chex.Array jump_cooldown: chex.Array - post_jump_cooldown: chex.Array is_jumping: chex.Array is_on_road: chex.Array player_car: Car - lives: chex.Array - is_dead: chex.Array - respawn_timer: chex.Array - step_counter: chex.Array - rng_key: chex.PRNGKey - round_started: chex.Array - movement_steps: chex.Array - steep_road_timer: chex.Array # Timer for steep road speed reduction - jump_slope: chex.Array # X movement per Y step, locked at jump start (float) - # Flag state - tracks all 8 flags - flags: Flag # Contains arrays of size NUM_FLAGS for each field - flags_collected_mask: chex.Array # Boolean mask of which flag colors have been collected (size NUM_FLAGS) - # Collectible state - dynamic spawning (mixed types: cherry, balloon, lollypop, ice cream) - collectibles: Collectible # Contains arrays of size MAX_COLLECTIBLES for each field - collectible_spawn_timer: chex.Array # Counter for collectible spawn timing - # Enemy cars - dynamic spawning and movement - enemy_cars: EnemyCars - enemy_spawn_timer: chex.Array - # Death/respawn state - player is dead and waiting for input to respawn - awaiting_respawn: chex.Array # True when player died and is waiting for input - # Round start state - everything frozen and hidden until player presses input - awaiting_round_start: chex.Array # True at game start and after respawn until input received - # Input debounce - requires button release before next input triggers round start - input_released: chex.Array # True when player has released all buttons since last state change - jump_key_released: chex.Array # True if jump button was NOT pressed in previous step - last_extra_life_score: chex.Array # Score at which last extra life was awarded - jump_total_duration: chex.Array # Total duration of the current/last jump for rendering arc + class UpNDownObservation(NamedTuple): - player_car: Car - enemy_cars: EnemyCars - flags: Flag - collectibles: Collectible - flags_collected_mask: chex.Array # Shape (NUM_FLAGS,) - int32 (0 or 1) - player_score: chex.Array - lives: chex.Array - is_jumping: chex.Array - jump_cooldown: chex.Array - is_on_steep_road: chex.Array - round_started: chex.Array + player: EntityPosition + +class Collectible(NamedTuple): + position: EntityPosition + type: chex.Array + value: chex.Array class UpNDownInfo(NamedTuple): - """Additional info for debugging and analysis.""" - step_counter: jnp.ndarray # Total steps taken - difficulty: jnp.ndarray # Current difficulty level - movement_steps: jnp.ndarray # Steps since round started - jump_slope: jnp.ndarray # Current jump trajectory slope - player_road_segment: jnp.ndarray # Current road segment index + time: jnp.ndarray + + class JaxUpNDown(JaxEnvironment[UpNDownState, UpNDownObservation, UpNDownInfo, UpNDownConstants]): def __init__(self, consts: UpNDownConstants = None, reward_funcs: list[callable]=None): consts = consts or UpNDownConstants() @@ -226,1523 +84,194 @@ def __init__(self, consts: UpNDownConstants = None, reward_funcs: list[callable] Action.DOWN, Action.DOWNFIRE, ] - # Calculate obs_size based on observation structure: - # Player car: 10 values (x, y, w, h, speed, type, road, road_index_A, road_index_B, direction_x) - # Enemy cars: MAX_ENEMY_CARS * 12 = 8 * 12 = 96 (x, y, w, h, speed, type, road, road_index_A, road_index_B, direction_x, active, age) - # Flags: NUM_FLAGS * 5 = 8 * 5 = 40 (y, road, segment, color, collected per flag) - # Collectibles: MAX_COLLECTIBLES * 6 = 1 * 6 = 6 (y, x, road, color_idx, type, active per collectible) - # Flags collected mask: NUM_FLAGS = 8 - # Score, lives, is_jumping, jump_cooldown, is_on_steep_road, round_started: 6 - # Total: 10 + 96 + 40 + 6 + 8 + 6 = 166 - self.obs_size = ( - 10 + # player car - self.consts.MAX_ENEMY_CARS * 12 + # enemy cars (all fields) - self.consts.NUM_FLAGS * 5 + # flags - self.consts.MAX_COLLECTIBLES * 6 + # collectibles (all fields) - self.consts.NUM_FLAGS + # flags_collected_mask - 6 # score, lives, is_jumping, jump_cooldown, is_on_steep_road, round_started - ) - # Speed dividers for movement timing (indexed by speed level) - self._speed_dividers = jnp.array([0, 1, 2, 4, 8, 16, 16, 16, 16]) - - @partial(jax.jit, static_argnums=(0,)) - def _compute_movement_timing(self, speed: chex.Array, step_counter: chex.Array) -> Tuple[chex.Array, chex.Array, chex.Array, chex.Array]: - """Calculate movement timing parameters based on speed. - - Returns: - Tuple of (move_y, move_x, step_size, speed_sign) - """ - abs_speed = jnp.abs(speed) - speed_index = jnp.minimum(abs_speed, jnp.int32(self._speed_dividers.shape[0] - 1)) - speed_divider = self._speed_dividers[speed_index] - effective_divider = jnp.maximum(1, speed_divider) - period = jnp.maximum(1, 16 // effective_divider) - half_period = jnp.maximum(1, period // 2) - speed_sign = jnp.sign(speed).astype(jnp.float32) - - move_y = jnp.logical_and((step_counter % period) == (half_period % period), speed != 0) - move_x = jnp.logical_and((step_counter % period) == 0, speed != 0) - step_size = jnp.where(speed_index >= 6, 1.5 + (speed_index - 6) * 0.2, 1.0) - - return move_y, move_x, step_size, speed_sign + self.obs_size = 3*4+1+1 @partial(jax.jit, static_argnums=(0,)) - def _sample_enemy_spawn_road(self, rng_key: chex.PRNGKey) -> chex.Array: - """Sample road index for enemy spawns. - - Extracted as a modding hook; default behavior is unchanged. - """ - return jax.random.randint(rng_key, shape=(), minval=0, maxval=2).astype(jnp.int32) - - @partial(jax.jit, static_argnums=(0,)) - def _collectible_score_values(self, state: UpNDownState, collectible_type_ids: chex.Array) -> chex.Array: - """Return score values for collectible types. - - Extracted as a modding hook; default behavior is unchanged. - """ - return self.consts.COLLECTIBLE_SCORES[collectible_type_ids] - - @partial(jax.jit, static_argnums=(0,)) - def _on_level_completed(self, state: UpNDownState) -> UpNDownState: - """Optional callback invoked only when all flags are collected. - - Default is a no-op and preserves existing game behavior. - """ - return state - - @partial(jax.jit, static_argnums=(0,)) - def _jump_speed_allows_start(self, player_speed: chex.Array) -> chex.Array: - """Return whether jump start is allowed for the current speed. - - Extracted as a modding hook; default behavior is unchanged. - """ - return player_speed >= 0 - - @partial(jax.jit, static_argnums=(0,)) - def _adjust_enemy_spawn_timer(self, state: UpNDownState, spawn_timer: chex.Array) -> chex.Array: - """Optional hook to post-process enemy spawn timer. - - Extracted as a modding hook; default behavior is unchanged. - """ - return spawn_timer - - @partial(jax.jit, static_argnums=(0,)) - def _get_slope_and_intercept_from_indices(self, current_road: chex.Array, road_index_A: chex.Array, road_index_B: chex.Array) -> Tuple[chex.Array, chex.Array]: - """Calculate slope and intercept for the current road segment.""" - road_index = jnp.where(current_road == 0, road_index_A, road_index_B) - x1 = jnp.where(current_road == 0, - self.consts.FIRST_TRACK_CORNERS_X[road_index], - self.consts.SECOND_TRACK_CORNERS_X[road_index]) - x2 = jnp.where(current_road == 0, - self.consts.FIRST_TRACK_CORNERS_X[road_index + 1], - self.consts.SECOND_TRACK_CORNERS_X[road_index + 1]) - y1 = self.consts.TRACK_CORNERS_Y[road_index] - y2 = self.consts.TRACK_CORNERS_Y[road_index + 1] - - dx = x2 - x1 - dy = y2 - y1 - slope = jnp.where(dx != 0, dy / dx, 300.0) - b = y1 - slope * x1 - return slope, b - - @partial(jax.jit, static_argnums=(0,)) - def _is_on_line_for_position(self, position: EntityPosition, slope: chex.Array, b: chex.Array, player_speed: chex.Array, turn: chex.Array) -> chex.Array: - x_step = abs(jnp.subtract(position.y, slope * (position.x) + b)) - y_step = abs(jnp.subtract(position.y - player_speed, slope * position.x + b)) - prefer_y = jnp.less_equal(y_step, x_step) - return jnp.logical_or( - jnp.logical_and(turn == 1, prefer_y), - jnp.logical_and(turn == 2, jnp.logical_not(prefer_y)), - ) - - @partial(jax.jit, static_argnums=(0,)) - def _get_x_on_road(self, y: chex.Array, road_segment: chex.Array, track_corners_x: chex.Array) -> chex.Array: - """Calculate the X position on a road given a Y coordinate and road segment.""" - y1 = self.consts.TRACK_CORNERS_Y[road_segment] - y2 = self.consts.TRACK_CORNERS_Y[road_segment + 1] - x1 = track_corners_x[road_segment] - x2 = track_corners_x[road_segment + 1] - - # Linear interpolation: x = x1 + (y - y1) * (x2 - x1) / (y2 - y1) - t = jnp.where(y2 != y1, (y - y1) / (y2 - y1), 0.0) - return x1 + t * (x2 - x1) - - @partial(jax.jit, static_argnums=(0,)) - def _get_x_for_road_index(self, y: chex.Array, road_segment: chex.Array, road_index: chex.Array) -> chex.Array: - """Get X position on road A (index 0) or road B (index 1) for given Y and segment.""" - track_corners = jnp.where( - road_index == 0, - self.consts.FIRST_TRACK_CORNERS_X[road_segment], - self.consts.SECOND_TRACK_CORNERS_X[road_segment], - ) - track_corners_next = jnp.where( - road_index == 0, - self.consts.FIRST_TRACK_CORNERS_X[road_segment + 1], - self.consts.SECOND_TRACK_CORNERS_X[road_segment + 1], - ) - y1 = self.consts.TRACK_CORNERS_Y[road_segment] - y2 = self.consts.TRACK_CORNERS_Y[road_segment + 1] - t = jnp.where(y2 != y1, (y - y1) / (y2 - y1), 0.0) - return track_corners + t * (track_corners_next - track_corners) - - @partial(jax.jit, static_argnums=(0,)) - def _get_road_segment(self, y: chex.Array) -> chex.Array: - """Return the road segment index for a given y position.""" - segments = jnp.sum(self.consts.TRACK_CORNERS_Y > y, dtype=jnp.int32) - max_idx = jnp.int32(len(self.consts.TRACK_CORNERS_Y) - 1) - return jnp.clip(segments - 1, 0, max_idx) - - @partial(jax.jit, static_argnums=(0,)) - def _compute_direction_x(self, current_road: chex.Array, road_index_A: chex.Array, road_index_B: chex.Array) -> chex.Array: - """Calculate the X direction for movement on the current road segment. - - Returns: - Direction as int32: -1 for left, 1 for right (defaults to -1 for vertical segments) - """ - # Select the road index based on which road we're on - road_index = jnp.where(current_road == 0, road_index_A, road_index_B) - # Select corners for the current road - x_curr = jnp.where(current_road == 0, - self.consts.FIRST_TRACK_CORNERS_X[road_index], - self.consts.SECOND_TRACK_CORNERS_X[road_index]) - x_next = jnp.where(current_road == 0, - self.consts.FIRST_TRACK_CORNERS_X[road_index + 1], - self.consts.SECOND_TRACK_CORNERS_X[road_index + 1]) - direction_raw = x_next - x_curr - return jnp.where(direction_raw == 0, -1, jnp.sign(direction_raw)).astype(jnp.int32) - - @partial(jax.jit, static_argnums=(0,)) - def _move_on_road( - self, - position: EntityPosition, - slope: chex.Array, - b: chex.Array, - speed_sign: chex.Array, - step_size: chex.Array, - car_direction_x: chex.Array, - move_y: chex.Array, - move_x: chex.Array, - ) -> Tuple[chex.Array, chex.Array]: - """Move a car on the road based on timing and geometry. - - Returns: - Tuple of (new_x, new_y) positions - """ - new_y = jnp.where( - jnp.logical_and(move_y, self._is_on_line_for_position(position, slope, b, speed_sign, 1)), - position.y + speed_sign * -step_size, - position.y, - ) - - new_x = jnp.where( - jnp.logical_and(move_x, self._is_on_line_for_position(position, slope, b, speed_sign, 2)), - position.x + speed_sign * car_direction_x * step_size, - position.x, - ) - - return new_x, new_y - - @partial(jax.jit, static_argnums=(0,)) - def _is_steep_road_segment(self, current_road: chex.Array, road_index_A: chex.Array, road_index_B: chex.Array) -> chex.Array: - """Check if the current road segment is steep (no X direction change). - - A steep segment is one where the X coordinates of consecutive corners are the same, - meaning the road goes straight up/down with no horizontal movement. - - Returns True if the segment is steep (requires jump to pass when going up). - """ - # Get the X difference for the current road segment - road_index = jnp.where(current_road == 0, road_index_A, road_index_B) - x_curr = jnp.where(current_road == 0, - self.consts.FIRST_TRACK_CORNERS_X[road_index], - self.consts.SECOND_TRACK_CORNERS_X[road_index]) - x_next = jnp.where(current_road == 0, - self.consts.FIRST_TRACK_CORNERS_X[road_index + 1], - self.consts.SECOND_TRACK_CORNERS_X[road_index + 1]) - x_diff = jnp.abs(x_next - x_curr) - # A segment is steep if there's no X change (or very small change) - return x_diff < 1.0 - - @partial(jax.jit, static_argnums=(0,)) - def _get_steep_segment_progress(self, position_y: chex.Array, current_road: chex.Array, - road_index_A: chex.Array, road_index_B: chex.Array) -> chex.Array: - """Calculate progress (0.0 to 1.0) through the current steep road segment. - - 0.0 = at the bottom (start) of the steep segment - 1.0 = at the top (end) of the steep segment - - Progress is measured in the direction of forward travel (upward = positive Y direction in game space, - but Y decreases as we go forward on the track). - """ - road_index = jnp.where(current_road == 0, road_index_A, road_index_B) - # Y coordinates of segment boundaries - y_start = self.consts.TRACK_CORNERS_Y[road_index] # Start of segment (lower Y = further ahead) - y_end = self.consts.TRACK_CORNERS_Y[road_index + 1] # End of segment (higher Y in absolute terms) - - # Calculate progress: how far through the segment are we? - # Since Y decreases as we go forward, we need to invert - segment_length = jnp.abs(y_end - y_start) - # Distance from segment start (in forward direction) - distance_from_start = jnp.abs(position_y - y_start) - - progress = jnp.where(segment_length > 0.001, distance_from_start / segment_length, 0.0) - return jnp.clip(progress, 0.0, 1.0) - - @partial(jax.jit, static_argnums=(0,)) - def _check_landing_position( - self, - road_index_A: chex.Array, - road_index_B: chex.Array, - new_position_x: chex.Array, - new_position_y: chex.Array, - ) -> Tuple[chex.Array, chex.Array, chex.Array, chex.Array]: - """Check if a position is valid for landing (on or between roads). - - Returns: - Tuple of (landing_in_water, between_roads, road_A_x, road_B_x) - """ - # Calculate X position on road A at the given Y - y_ratio_A = (new_position_y - self.consts.TRACK_CORNERS_Y[road_index_A]) / ( - self.consts.TRACK_CORNERS_Y[road_index_A + 1] - self.consts.TRACK_CORNERS_Y[road_index_A] - ) - road_A_x = y_ratio_A * ( - self.consts.FIRST_TRACK_CORNERS_X[road_index_A + 1] - self.consts.FIRST_TRACK_CORNERS_X[road_index_A] - ) + self.consts.FIRST_TRACK_CORNERS_X[road_index_A] - - # Calculate X position on road B at the given Y - y_ratio_B = (new_position_y - self.consts.TRACK_CORNERS_Y[road_index_B]) / ( - self.consts.TRACK_CORNERS_Y[road_index_B + 1] - self.consts.TRACK_CORNERS_Y[road_index_B] - ) - road_B_x = y_ratio_B * ( - self.consts.SECOND_TRACK_CORNERS_X[road_index_B + 1] - self.consts.SECOND_TRACK_CORNERS_X[road_index_B] - ) + self.consts.SECOND_TRACK_CORNERS_X[road_index_B] - - distance_to_road_A = jnp.abs(new_position_x - road_A_x) - distance_to_road_B = jnp.abs(new_position_x - road_B_x) - landing_in_water = jnp.logical_and( - distance_to_road_A > self.consts.LANDING_TOLERANCE, - distance_to_road_B > self.consts.LANDING_TOLERANCE, - ) - between_roads = jnp.logical_and( - new_position_x > jnp.minimum(road_A_x, road_B_x), - new_position_x < jnp.maximum(road_A_x, road_B_x), - ) - return landing_in_water, between_roads, road_A_x, road_B_x - - @partial(jax.jit, static_argnums=(0,)) - def _advance_player_car( - self, - position_x: chex.Array, - position_y: chex.Array, - road_index_A: chex.Array, - road_index_B: chex.Array, - current_road: chex.Array, - speed: chex.Array, - is_jumping: chex.Array, - step_counter: chex.Array, - width: chex.Array, - height: chex.Array, - car_type: chex.Array, - is_landing: chex.Array, - stored_jump_slope: chex.Array, - jump_progress: chex.Array, - ) -> Car: - """ - Advance the player car position. - - Jump logic: - - Car jumps in the direction of the road it's on at current speed - - While jumping, car moves freely (not constrained to road) - - On landing: check if car is on/near a road or between roads - - If between roads: snap to nearest road - - If too far from both roads (outside the road area): crash (water) - """ - # Calculate movement timing using helper - move_y, move_x, step_size, speed_sign = self._compute_movement_timing(speed, step_counter) - - # Get slope and intercept for current road - slope, b = self._get_slope_and_intercept_from_indices(current_road, road_index_A, road_index_B) - - # Determine X direction based on current road segment (for normal movement) - car_direction_x = self._compute_direction_x(current_road, road_index_A, road_index_B) - - position = EntityPosition(x=position_x, y=position_y, width=width, height=height) - - # === CALCULATE ROAD-BASED MOVEMENT (used when not jumping) === - road_x, road_y = self._move_on_road( - position, slope, b, speed_sign, step_size, car_direction_x, move_y, move_x - ) - - # === JUMP PHYSICS NORMALIZATION === - # Normalize jump velocity so total speed (Euclidean) matches 'step_size' - # Without this, diagonal jumps cover more distance per frame than straight road movement - # stored_jump_slope is dX/dY - # Scaling factor = 1 / sqrt(1 + slope^2) - jump_speed_scaling = 1.0 / jnp.sqrt(1.0 + stored_jump_slope**2) - jump_step_size = step_size * jump_speed_scaling - - # === Y MOVEMENT === - # When jumping: move freely in Y direction but with normalized speed - # When on road: use road-based movement result - # Note: We must apply step_y on move_y ticks to keep sync with engine heartbeat - jump_y = jnp.where(move_y, position_y + speed_sign * -jump_step_size, position_y) - new_player_y = jnp.where(is_jumping, jump_y, road_y) - - # === X MOVEMENT === - # When jumping: use stored_jump_slope (locked at jump start) - moves X proportionally to Y - # Use jump_step_size to maintain correct trajectory and speed - # X step = slope * Y step magnitude = slope * jump_step_size - raw_jump_x = jnp.where(move_x, position_x - speed_sign * stored_jump_slope * jump_step_size, position_x) - - # === AIR STEERING / MAGNETISM === - # Gradually steer towards the nearest road while in the air to prevent "teleporting" on landing - segment_curr = self._get_road_segment(new_player_y) - road_A_x_curr = self._get_x_on_road(new_player_y, segment_curr, self.consts.FIRST_TRACK_CORNERS_X) - road_B_x_curr = self._get_x_on_road(new_player_y, segment_curr, self.consts.SECOND_TRACK_CORNERS_X) - - dist_A = jnp.abs(raw_jump_x - road_A_x_curr) - dist_B = jnp.abs(raw_jump_x - road_B_x_curr) - - # Find closest road center - target_road_x = jnp.where(dist_A < dist_B, road_A_x_curr, road_B_x_curr) - dist_to_target = target_road_x - raw_jump_x - - # Only nudge in the last 25% of the jump (progress > 0.75) - # when reasonably close to a road (within 2x tolerance) - # and only when player is between the two roads - - is_late_jump = jump_progress > 0.75 - is_reasonably_close = jnp.abs(dist_to_target) < (self.consts.LANDING_TOLERANCE * 2.0) - - # Check if player is between the two roads - min_road_x_curr = jnp.minimum(road_A_x_curr, road_B_x_curr) - max_road_x_curr = jnp.maximum(road_A_x_curr, road_B_x_curr) - is_between_roads = jnp.logical_and(raw_jump_x > min_road_x_curr, raw_jump_x < max_road_x_curr) - - should_magnet = jnp.logical_and(is_late_jump, jnp.logical_and(is_reasonably_close, is_between_roads)) - - # Nudge factor: reduced to 2% steering strength (very subtle) - nudge_amount = dist_to_target * 0.08 - - jump_x = raw_jump_x + jnp.where(should_magnet, nudge_amount, 0.0) - - new_player_x = jnp.where(is_jumping, jump_x, road_x) - - # === LANDING LOGIC === - # Get the current road segment based on new Y position - segment = self._get_road_segment(new_player_y) - - # Calculate X positions of both roads at the new Y position - road_A_x = self._get_x_on_road(new_player_y, segment, self.consts.FIRST_TRACK_CORNERS_X) - road_B_x = self._get_x_on_road(new_player_y, segment, self.consts.SECOND_TRACK_CORNERS_X) - - # Calculate distances to each road - dist_to_road_A = jnp.abs(new_player_x - road_A_x) - dist_to_road_B = jnp.abs(new_player_x - road_B_x) - - # Check if player is close enough to either road (within tolerance) - on_road_A = dist_to_road_A <= self.consts.LANDING_TOLERANCE - on_road_B = dist_to_road_B <= self.consts.LANDING_TOLERANCE - on_any_road = jnp.logical_or(on_road_A, on_road_B) - - # Check if player is between the two roads - min_road_x = jnp.minimum(road_A_x, road_B_x) - max_road_x = jnp.maximum(road_A_x, road_B_x) - between_roads = jnp.logical_and(new_player_x > min_road_x, new_player_x < max_road_x) - - # Determine which road is closer - closer_to_A = dist_to_road_A < dist_to_road_B - nearest_road_x = jnp.where(closer_to_A, road_A_x, road_B_x) - nearest_road_id = jnp.where(closer_to_A, jnp.int32(0), jnp.int32(1)) - - # === LANDING OUTCOMES === - # Valid landing: on a road OR between roads (will snap to nearest) - valid_landing = jnp.logical_or(on_any_road, between_roads) - - # Bridge crossing physics: if speed is high, we can "skip" small water gaps (land on nearest road) - # In original game, bridges allow crossing without jumping if you have speed - can_bridge_gap = jnp.abs(speed) >= 5 - - # If landing and between roads but not directly on a road, snap to nearest road - should_snap = jnp.logical_and(is_landing, jnp.logical_and(between_roads, jnp.logical_not(on_any_road))) - # Also snap if we are "in water" but have speed to bridge the gap - should_snap_bridge = jnp.logical_and(is_landing, jnp.logical_and(can_bridge_gap, jnp.logical_not(valid_landing))) - - final_player_x = jnp.where(jnp.logical_or(should_snap, should_snap_bridge), nearest_road_x, new_player_x) - - # Water landing (crash): Only if NOT on road AND NOT between roads (i.e., landed completely outside) - # User clarification: "crashing should only be possible if you dont land in betweeen or on the roads" - - # Safe if: ON ROAD or BETWEEN ROADS - is_safe_landing = jnp.logical_or(on_any_road, between_roads) - - landing_in_water = jnp.logical_and( - is_landing, - jnp.logical_not(is_safe_landing) - ) - - # Snap logic: - # If landing BETWEEN roads but not ON a road -> snap to nearest (safe!) - # (Outside landings are now crashes, so no need to snap them) - should_snap = jnp.logical_and(is_landing, jnp.logical_and(between_roads, jnp.logical_not(on_any_road))) - - # Also snap if bridging (fast jump across water gap) - should_snap_bridge = jnp.logical_and(is_landing, jnp.logical_and(between_roads, can_bridge_gap)) - - final_player_x = jnp.where( - jnp.logical_or(should_snap, should_snap_bridge), - nearest_road_x, - new_player_x + def _car_past_corner(self, car: Car, state: UpNDownState) -> chex.Array: + direction_change_A = jnp.logical_or(jnp.logical_and(car.speed > 0, car.position.y + car.speed > self.consts.FIRST_TRACK_CORNERS_Y[car.road_index_A+1]), jnp.logical_and(car.speed < 0, car.position.y + car.speed < self.consts.FIRST_TRACK_CORNERS_Y[car.road_index_A])) + direction_change_B = jnp.logical_or(jnp.logical_and(car.speed > 0, car.position.y + car.speed > self.consts.SECOND_TRACK_CORNERS_Y[car.road_index_B+1]), jnp.logical_and(car.speed < 0, car.position.y + car.speed < self.consts.SECOND_TRACK_CORNERS_Y[car.road_index_B])) + + road_index_A = jax.lax.cond(jnp.logical_and(direction_change_A, car.speed > 0), + lambda s: s + 1, + lambda s: s, + operand=car.road_index_A, ) - - # === UPDATE ROAD STATE === - # Determine which road to assign on landing (priority: road A > road B > nearest) - landed_road = jnp.where(on_road_A, jnp.int32(0), jnp.where(on_road_B, jnp.int32(1), nearest_road_id)) - - # Update current_road using nested jnp.where for vectorized execution - # Priority: water crash > landing > jumping (frozen) > recover from water > normal - normal_road = jnp.where(current_road == 2, nearest_road_id, current_road) - jumping_road = jnp.where(is_jumping, current_road, normal_road) - landing_road = jnp.where(is_landing, landed_road, jumping_road) - updated_current_road = jnp.where(landing_in_water, jnp.int32(2), landing_road) - - # Update road indices to match current segment when not jumping - not_jumping_on_road_A = jnp.logical_and(jnp.logical_not(is_jumping), updated_current_road == 0) - not_jumping_on_road_B = jnp.logical_and(jnp.logical_not(is_jumping), updated_current_road == 1) - next_road_index_A = jnp.where(not_jumping_on_road_A, segment, road_index_A) - next_road_index_B = jnp.where(not_jumping_on_road_B, segment, road_index_B) - - # Wrap Y position for looping track - wrapped_y = -((new_player_y * -1) % self.consts.TRACK_LENGTH) - - return Car( - position=EntityPosition( - x=final_player_x, - y=wrapped_y, - width=width, - height=height, - ), - speed=speed, - direction_x=car_direction_x, - current_road=updated_current_road, - road_index_A=next_road_index_A, - road_index_B=next_road_index_B, - type=car_type, + road_index_A = jax.lax.cond(jnp.logical_and(direction_change_A, car.speed < 0), + lambda s: s - 1, + lambda s: s, + operand=car.road_index_A, ) - @partial(jax.jit, static_argnums=(0,)) - def _advance_car_core( - self, - position_x: chex.Array, - position_y: chex.Array, - road_index_A: chex.Array, - road_index_B: chex.Array, - current_road: chex.Array, - speed: chex.Array, - step_counter: chex.Array, - width: chex.Array, - height: chex.Array, - car_type: chex.Array, - ) -> Car: - """Simplified car advancement for enemy cars (no jumping/landing logic).""" - # Calculate movement timing using helper - move_y, move_x, step_size, speed_sign = self._compute_movement_timing(speed, step_counter) - slope, b = self._get_slope_and_intercept_from_indices(current_road, road_index_A, road_index_B) - car_direction_x = self._compute_direction_x(current_road, road_index_A, road_index_B) - - position = EntityPosition(x=position_x, y=position_y, width=width, height=height) - - # Use shared movement helper - new_x, new_y = self._move_on_road( - position, slope, b, speed_sign, step_size, car_direction_x, move_y, move_x + road_index_B = jax.lax.cond(jnp.logical_and(direction_change_B, car.speed > 0), + lambda s: s + 1, + lambda s: s, + operand=car.road_index_B, ) - - wrapped_y = -((new_y * -1) % self.consts.TRACK_LENGTH) - - # Update road segment indices based on new position - segment_from_y = self._get_road_segment(new_y) - - # Update road indices to track the current segment (use jnp.where for branchless execution) - next_road_index_A = jnp.where(current_road == 0, segment_from_y, road_index_A) - next_road_index_B = jnp.where(current_road == 1, segment_from_y, road_index_B) - - return Car( - position=EntityPosition( - x=new_x, - y=wrapped_y, - width=width, - height=height, - ), - speed=speed, - direction_x=car_direction_x, - current_road=current_road, - road_index_A=next_road_index_A, - road_index_B=next_road_index_B, - type=car_type, + road_index_B = jax.lax.cond(jnp.logical_and(direction_change_B, car.speed < 0), + lambda s: s - 1, + lambda s: s, + operand=car.road_index_B, ) + current_road_length_A = self.consts.FIRST_ROAD_LENGTH + current_road_length_B = self.consts.SECOND_ROAD_LENGTH - @partial(jax.jit, static_argnums=(0,)) - def _flag_step(self, state: UpNDownState, new_player_y: chex.Array, player_x: chex.Array, current_road: chex.Array) -> Tuple[Flag, chex.Array, chex.Array]: - """Update flag collection state and score (vectorized).""" - # Calculate flag X positions on both roads - # _get_x_on_road supports array inputs via advanced indexing - x_road_0 = self._get_x_on_road(state.flags.y, state.flags.road_segment, self.consts.FIRST_TRACK_CORNERS_X) - x_road_1 = self._get_x_on_road(state.flags.y, state.flags.road_segment, self.consts.SECOND_TRACK_CORNERS_X) - - flag_x = jnp.where(state.flags.road == 0, x_road_0, x_road_1) - - # Vectorized distance check - y_dist = jnp.abs(new_player_y - state.flags.y) - x_dist = jnp.abs(player_x - flag_x) - same_road = (current_road == state.flags.road) - - new_collections = jnp.logical_and( - jnp.logical_and(y_dist < self.consts.COLLISION_THRESHOLD, x_dist < self.consts.COLLISION_THRESHOLD), - jnp.logical_and(same_road, ~state.flags.collected) - ) - - # Update flags collected state - new_flags_collected = jnp.logical_or(state.flags.collected, new_collections) - new_flags_collected_mask = jnp.logical_or(state.flags_collected_mask, new_collections) - - # Update score based on collected flags - flag_score = jnp.sum(new_collections.astype(jnp.int32) * self.consts.FLAG_COLLECTION_SCORE) - - new_flags = Flag( - y=state.flags.y, - road=state.flags.road, - road_segment=state.flags.road_segment, - color_idx=state.flags.color_idx, - collected=new_flags_collected, + road_index_A = jax.lax.cond(road_index_A < 0, + lambda s: current_road_length_A - 1, + lambda s: s, + operand=road_index_A, ) - - return new_flags, flag_score, new_flags_collected_mask - - @partial(jax.jit, static_argnums=(0,)) - def _collectible_step(self, state: UpNDownState, new_player_y: chex.Array, player_x: chex.Array, current_road: chex.Array, rng_key: chex.PRNGKey) -> Tuple[Collectible, chex.Array, chex.Array, chex.PRNGKey]: - """Update collectible spawning, despawning, and collection (unified for all types). - Handles mixed-type collectibles (cherry, balloon, lollypop, ice cream) in a single pool. - Type is randomized on spawn with probabilities defined in COLLECTIBLE_SPAWN_PROBABILITIES. - - Args: - state: Current game state - new_player_y: Updated player Y position after movement - player_x: Current player X position - current_road: Current road player is on - rng_key: PRNG key to drive spawn randomness - - Returns: - Tuple of (updated_collectibles, score_delta, new_spawn_timer, new_rng_key) - """ - rng_key, key1, key2, key3, key4 = jax.random.split(rng_key, 5) - - # Collectible spawning logic - decrement timer and spawn when ready (use jnp.where for branchless) - new_collectible_timer = jnp.where( - state.collectible_spawn_timer <= 0, - self.consts.COLLECTIBLE_SPAWN_INTERVAL, - state.collectible_spawn_timer - 1, - ) - - # Attempt to spawn when timer hits 0 - should_spawn = state.collectible_spawn_timer <= 0 - - inactive_mask = ~state.collectibles.active - first_inactive = jnp.argmax(inactive_mask.astype(jnp.int32)) - has_inactive_slot = jnp.any(inactive_mask) - spawn_idx = jnp.where(has_inactive_slot, first_inactive, jnp.array(0, dtype=jnp.int32)) - - y_spawn = jax.random.uniform(key1, minval=-900.0, maxval=-100.0) - road_spawn = jnp.array(jax.random.randint(key2, shape=(), minval=0, maxval=2), dtype=jnp.int32) - color_spawn = jnp.array(jax.random.randint(key3, shape=(), minval=0, maxval=len(self.consts.COLLECTIBLE_COLORS)), dtype=jnp.int32) - - # Randomly select collectible type using cumulative probability thresholds - # COLLECTIBLE_SPAWN_PROBABILITIES contains cumulative values: [35, 65, 90, 100] - # Cherry: [0-35), Balloon: [35-65), Lollypop: [65-90), IceCream: [90-100] - rand_type = jax.random.uniform(key4, minval=0.0, maxval=100.0) - - # Use searchsorted for efficient threshold lookup - type_id_spawn = jnp.searchsorted(self.consts.COLLECTIBLE_SPAWN_PROBABILITIES, rand_type, side='right') - type_id_spawn = jnp.clip(type_id_spawn, 0, 3).astype(jnp.int32) - - # Calculate X position on road (use jnp.where for branchless) - segment_spawn = self._get_road_segment(y_spawn) - x_spawn = jnp.where( - road_spawn == 0, - self._get_x_on_road(y_spawn, segment_spawn, self.consts.FIRST_TRACK_CORNERS_X), - self._get_x_on_road(y_spawn, segment_spawn, self.consts.SECOND_TRACK_CORNERS_X), - ) - - # Create mask for which collectibles to update - update_mask = (jnp.arange(self.consts.MAX_COLLECTIBLES) == spawn_idx) & should_spawn & has_inactive_slot - - # Update collectibles with proper masking - spawn new items - spawned_y = jnp.where(update_mask, y_spawn, state.collectibles.y) - spawned_x = jnp.where(update_mask, x_spawn, state.collectibles.x) - spawned_road = jnp.where(update_mask, road_spawn, state.collectibles.road) - spawned_color_idx = jnp.where(update_mask, color_spawn, state.collectibles.color_idx) - spawned_type_id = jnp.where(update_mask, type_id_spawn, state.collectibles.type_id) - spawned_active = jnp.where(update_mask, True, state.collectibles.active) - - # Despawn logic - remove collectibles too far from player - def check_despawn(idx): - c_y = spawned_y[idx] - c_active = spawned_active[idx] - distance = jnp.abs(new_player_y - c_y) - too_far = distance > self.consts.COLLECTIBLE_DESPAWN_DISTANCE - should_despawn = jnp.logical_and(c_active, too_far) - return should_despawn - - despawn_mask = jax.vmap(check_despawn)(jnp.arange(self.consts.MAX_COLLECTIBLES)) - active_after_despawn = jnp.logical_and(spawned_active, ~despawn_mask) - - # Collision detection - def check_collision(idx): - c_y = spawned_y[idx] - c_x = spawned_x[idx] - c_road = spawned_road[idx] - c_active = spawned_active[idx] - - y_distance = jnp.abs(new_player_y - c_y) - x_distance = jnp.abs(player_x - c_x) - same_road = (current_road == c_road) - - collision = jnp.logical_and( - jnp.logical_and(y_distance < self.consts.COLLISION_THRESHOLD, x_distance < self.consts.COLLISION_THRESHOLD), - jnp.logical_and(same_road, c_active) - ) - return collision - - collections = jax.vmap(check_collision)(jnp.arange(self.consts.MAX_COLLECTIBLES)) - - # Deactivate collected items - final_active = jnp.logical_and(active_after_despawn, ~collections) - - # Update score - extracted into hook for easier modding - scores = self._collectible_score_values(state, spawned_type_id) - score_delta = jnp.sum(jnp.where(collections, scores, 0)) - - # Create final collectibles state - updated_collectibles = Collectible( - y=spawned_y, - x=spawned_x, - road=spawned_road, - color_idx=spawned_color_idx, - type_id=spawned_type_id, - active=final_active, + road_index_A = jax.lax.cond(road_index_A >= current_road_length_A, + lambda s: 0, + lambda s: s, + operand=road_index_A, ) - - return updated_collectibles, score_delta, new_collectible_timer, rng_key - @partial(jax.jit, static_argnums=(0,)) - def _death_step(self, state: UpNDownState) -> UpNDownState: - """Handle player death - this is now only used for water crashes during landing. - - When the player dies: - - Lives are decremented - - is_dead is set to True - - awaiting_respawn is set to True - - Player car is moved off-screen (despawned) - - Game waits for player input before respawning - """ - # Skip if already awaiting respawn - already_awaiting = state.awaiting_respawn - - # Player on water road (index 2 assumed water) and not already dead - died = jnp.logical_and( - jnp.logical_and( - state.player_car.current_road == 2, - ~state.is_dead, - ), - ~already_awaiting, + road_index_B = jax.lax.cond(road_index_B < 0, + lambda s: current_road_length_B - 1, + lambda s: s, + operand=road_index_B, ) - # Use jnp.where for branchless execution - lives = jnp.where(died, state.lives - 1, state.lives) - is_dead = jnp.logical_or(state.is_dead, died) - awaiting_respawn = jnp.logical_or(state.awaiting_respawn, died) - - # Stop player movement but keep position (renderer will hide player when awaiting_respawn) - player_car = state.player_car._replace( - speed=jnp.where(died, 0, state.player_car.speed), + road_index_B = jax.lax.cond(road_index_B >= current_road_length_B, + lambda s: 0, + lambda s: s, + operand=road_index_B, ) - return state._replace( - lives=lives, - is_dead=is_dead, - awaiting_respawn=awaiting_respawn, - player_car=player_car, - ) + return road_index_A, road_index_B - @partial(jax.jit, static_argnums=(0,)) + def _landing_in_water(self, state: UpNDownState, new_position_x: chex.Array, new_position_y: chex.Array) -> chex.Array: + road_A_x = ((new_position_y - self.consts.FIRST_TRACK_CORNERS_Y[state.player_car.road_index_A]) / (self.consts.FIRST_TRACK_CORNERS_Y[state.player_car.road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_Y[state.player_car.road_index_A])) * (self.consts.FIRST_TRACK_CORNERS_X[state.player_car.road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_X[state.player_car.road_index_A]) + self.consts.FIRST_TRACK_CORNERS_X[state.player_car.road_index_A] + road_B_x = ((new_position_y - self.consts.SECOND_TRACK_CORNERS_Y[state.player_car.road_index_B]) / (self.consts.SECOND_TRACK_CORNERS_Y[state.player_car.road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_Y[state.player_car.road_index_B])) * (self.consts.SECOND_TRACK_CORNERS_X[state.player_car.road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[state.player_car.road_index_B]) + self.consts.SECOND_TRACK_CORNERS_X[state.player_car.road_index_B] + distance_to_road_A = jnp.abs(new_position_x - road_A_x) + distance_to_road_B = jnp.abs(new_position_x - road_B_x) + landing_in_Water = jnp.logical_and(distance_to_road_A > self.consts.LANDING_ZONE, distance_to_road_B > self.consts.LANDING_ZONE) + between_roads = jnp.logical_and(new_position_x > jnp.minimum(road_A_x, road_B_x), new_position_x < jnp.maximum(road_A_x, road_B_x)) + return landing_in_Water, between_roads, road_A_x, road_B_x + def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: up = jnp.logical_or(action == Action.UP, action == Action.UPFIRE) down = jnp.logical_or(action == Action.DOWN, action == Action.DOWNFIRE) - jump_pressed = jnp.logical_or(action == Action.FIRE, jnp.logical_or(action == Action.UPFIRE, action == Action.DOWNFIRE)) - - # Check if on a steep road section FIRST (before applying speed changes) - is_on_steep_road = self._is_steep_road_segment( - state.player_car.current_road, - state.player_car.road_index_A, - state.player_car.road_index_B, - ) - - # Calculate progress through steep segment (0.0 = bottom, 1.0 = top) - steep_progress = self._get_steep_segment_progress( - state.player_car.position.y, - state.player_car.current_road, - state.player_car.road_index_A, - state.player_car.road_index_B, - ) - - # Determine if player is on steep road going up (not jumping) - on_steep_not_jumping = jnp.logical_and(is_on_steep_road, jnp.logical_not(state.is_jumping)) - - # Start with current speed - player_speed = state.player_car.speed - - # === FRICTION & MOMENTUM LOGIC === - is_accelerating = up - is_braking = down - - # No friction - speed stays constant when no input - # Speed changes gradually (periodically, not every frame) - should_change_speed = (state.step_counter % self.consts.ACCELERATION_INTERVAL) == 0 - - # === ACCELERATION (UP) === - # On steep road: UP action has NO effect (can't accelerate while on steep section) - can_accelerate = jnp.logical_not(on_steep_not_jumping) - - player_speed = jnp.where( - jnp.logical_and( - jnp.logical_and(should_change_speed, is_accelerating), - jnp.logical_and(player_speed < self.consts.MAX_SPEED, can_accelerate) - ), - player_speed + 1, - player_speed, - ) - - # === BRAKING (DOWN) === - # DOWN action always works (can brake/reverse) - player_speed = jnp.where( - jnp.logical_and( - jnp.logical_and(should_change_speed, is_braking), - player_speed > -self.consts.MAX_SPEED - ), - player_speed - 1, - player_speed, - ) - - # === STEEP ROAD SPEED REDUCTION & SLIDE BACK === - # Only apply when on steep road, not jumping, and trying to go up (positive speed) - on_steep_going_up = jnp.logical_and(on_steep_not_jumping, player_speed > 0) - - # Update steep road timer - increment when on steep road going up - steep_road_timer = jnp.where( - on_steep_going_up, - state.steep_road_timer + 1, - jnp.array(0, dtype=jnp.int32), - ) - - # Check if player has reached halfway point (50% progress through segment) - past_halfway = steep_progress >= 0.5 - - # Check if player has enough momentum to climb steep road - MIN_CLIMB_SPEED = 5 - has_momentum = player_speed >= MIN_CLIMB_SPEED - - # Two behaviors based on progress: - # 1. Before halfway: gradually reduce speed using timer - # 2. At/past halfway: immediately slide back UNLESS we have enough momentum - - # Before halfway: reduce speed periodically using timer - should_reduce_speed = jnp.logical_and( - on_steep_going_up, - jnp.logical_and( - jnp.logical_not(past_halfway), - steep_road_timer >= self.consts.STEEP_ROAD_SPEED_REDUCTION_INTERVAL - ) - ) - player_speed = jnp.where( - should_reduce_speed, - jnp.maximum(player_speed - 1, jnp.int32(0)), # Reduce but not below 0 yet - player_speed, - ) - # Reset timer after speed reduction - steep_road_timer = jnp.where( - should_reduce_speed, - jnp.array(0, dtype=jnp.int32), - steep_road_timer, - ) - - # At/past halfway: force speed to -2 (slide back down) IF momentum is lost - should_slide_back = jnp.logical_and( - on_steep_going_up, - jnp.logical_and(past_halfway, jnp.logical_not(has_momentum)) - ) - player_speed = jnp.where( - should_slide_back, - jnp.int32(-3), - player_speed, - ) + jump = jnp.logical_or(action == Action.FIRE, jnp.logical_or(action == Action.UPFIRE, action == Action.DOWNFIRE)) - # === JUMP LOGIC === - can_start_jump = jnp.logical_and( - state.jump_cooldown == 0, - jnp.logical_and(state.post_jump_cooldown == 0, state.jump_key_released) - ) - is_jumping = jnp.logical_or( - jnp.logical_and(state.is_jumping, state.jump_cooldown > 0), - jnp.logical_and( - state.is_on_road, - jnp.logical_and( - self._jump_speed_allows_start(player_speed), - jnp.logical_and(can_start_jump, jump_pressed), - ), - ), - ) - - # Detect when a new jump is starting (was not jumping, now is jumping) - starting_jump = jnp.logical_and(is_jumping, jnp.logical_not(state.is_jumping)) - - # Calculate jump slope at jump start (X change per Y step) - # Uses the road segment slope to follow the road trajectory - # Use jnp.where for branchless execution - road_index = jnp.where( - state.player_car.current_road == 0, - state.player_car.road_index_A, - state.player_car.road_index_B, - ) - - # Get corner coordinates for the current segment - # Segment goes from corner[road_index] to corner[road_index+1] - # Use jnp.where for branchless execution - start_x = jnp.where( - state.player_car.current_road == 0, - self.consts.FIRST_TRACK_CORNERS_X[road_index], - self.consts.SECOND_TRACK_CORNERS_X[road_index], - ) - end_x = jnp.where( - state.player_car.current_road == 0, - self.consts.FIRST_TRACK_CORNERS_X[road_index + 1], - self.consts.SECOND_TRACK_CORNERS_X[road_index + 1], - ) - start_y = self.consts.TRACK_CORNERS_Y[road_index] - end_y = jnp.where( - jnp.equal(self.consts.FIRST_TRACK_CORNERS_X[road_index + 1], self.consts.FIRST_TRACK_CORNERS_X[road_index + 2]), - self.consts.TRACK_CORNERS_Y[road_index + 2], - self.consts.TRACK_CORNERS_Y[road_index + 1], - ) - - # Calculate slope: how much X changes per unit Y change - delta_x = end_x - start_x - delta_y = end_y - start_y - # Avoid division by zero for horizontal segments (use jnp.where) - new_jump_slope = jnp.where( - jnp.abs(delta_y) > 0.001, - jnp.float32(delta_x) / jnp.float32(delta_y), - jnp.float32(0.0), - ) - - # Lock slope at jump start, keep previous slope during jump (use jnp.where) - jump_slope = jnp.where(starting_jump, new_jump_slope, state.jump_slope) - - # Calculate dynamic jump duration based on speed - # Faster speed = shorter jump duration (covering gap faster) - # Increased base duration for more "air time" as requested - # Formula: 48 - 2 * abs(speed) -> Speed 8 = 32 frames (was 24 before) - current_jump_duration = 48 - 2 * jnp.abs(player_speed) - jump_duration = jnp.where(starting_jump, current_jump_duration.astype(jnp.int32), state.jump_total_duration) - - # Use jnp.where for branchless execution of jump_cooldown - jump_cooldown = jnp.where( - state.jump_cooldown > 0, - state.jump_cooldown - 1, - jnp.where(is_jumping, jump_duration, 0), - ) - # Use jnp.where for branchless execution of post_jump_cooldown - is_landing_now = jnp.logical_and(state.jump_cooldown == 1, jump_cooldown == 0) - post_jump_cooldown = jnp.where( - is_landing_now, - self.consts.POST_JUMP_DELAY, - jnp.where(state.post_jump_cooldown > 0, state.post_jump_cooldown - 1, state.post_jump_cooldown), - ) - is_on_road = ~is_jumping - is_landing = is_landing_now - - # Calculate jump progress for magnetism - # Progress = (Total - Remaining) / Total - # Use jnp.maximum(..., 1.0) to avoid division by zero - safe_total_duration = jnp.maximum(state.jump_total_duration, 1.0) - jump_progress = (safe_total_duration - jump_cooldown.astype(jnp.float32)) / safe_total_duration - jump_progress = jnp.clip(jump_progress, 0.0, 1.0) - - updated_player_car = self._advance_player_car( - position_x=state.player_car.position.x, - position_y=state.player_car.position.y, - road_index_A=state.player_car.road_index_A, - road_index_B=state.player_car.road_index_B, - current_road=state.player_car.current_road, - speed=player_speed, - is_jumping=is_jumping, - step_counter=state.step_counter, - width=state.player_car.position.width, - height=state.player_car.position.height, - car_type=state.player_car.type, - is_landing=is_landing, - stored_jump_slope=jump_slope, - jump_progress=jump_progress, - ) - - # Check if a speed-changing action (UP or DOWN) was taken - speed_action_taken = jnp.logical_or(up, down) - # Round starts only after a speed-changing action - round_started_now = jnp.logical_or(state.round_started, speed_action_taken) - - # Track jump key release for preventing held-key jumps - next_jump_key_released = jnp.logical_not(jump_pressed) - - next_state = state._replace( - jump_cooldown=jump_cooldown, - post_jump_cooldown=post_jump_cooldown, - is_jumping=is_jumping, - is_on_road=is_on_road, - player_car=updated_player_car, - step_counter=state.step_counter + 1, - round_started=round_started_now, - movement_steps=jnp.where(round_started_now, state.movement_steps + 1, state.movement_steps), - steep_road_timer=steep_road_timer, - jump_slope=jump_slope, - jump_key_released=next_jump_key_released, - jump_total_duration=jump_duration, - ) - - water_crash = jnp.logical_and(is_landing, updated_player_car.current_road == 2) - - # On water crash, trigger death state instead of immediate respawn - def trigger_death(s): - # Stop player but keep position (renderer will hide player when awaiting_respawn) - dead_car = s.player_car._replace( - speed=jnp.array(0, dtype=jnp.int32), - ) - return s._replace( - lives=s.lives - 1, - is_dead=jnp.array(True), - awaiting_respawn=jnp.array(True), - player_car=dead_car, - ) - - return jax.lax.cond( - water_crash, - lambda _: trigger_death(next_state), - lambda _: next_state, - operand=None, - ) - - @partial(jax.jit, static_argnums=(0,)) - def _flag_step_main(self, state: UpNDownState) -> UpNDownState: - """Update flag collection state and score.""" - new_player_y = state.player_car.position.y - player_x = state.player_car.position.x - current_road = state.player_car.current_road - - new_flags, flag_score, new_flags_collected_mask = self._flag_step( - state, new_player_y, player_x, current_road - ) - - return state._replace( - score=state.score + flag_score, - flags=new_flags, - flags_collected_mask=new_flags_collected_mask, - ) + player_speed = state.player_car.speed - @partial(jax.jit, static_argnums=(0,)) - def _level_progression_step(self, state: UpNDownState) -> UpNDownState: - """Handle level completion: award bonus and reset flags.""" - all_flags_collected = jnp.all(state.flags_collected_mask) - - bonus = jnp.where(all_flags_collected, self.consts.ALL_FLAGS_BONUS, 0) - - # Reset flags if all collected - new_collected = jnp.where(all_flags_collected, jnp.zeros_like(state.flags.collected), state.flags.collected) - new_mask = jnp.where(all_flags_collected, jnp.zeros_like(state.flags_collected_mask), state.flags_collected_mask) - - updated_flags = state.flags._replace(collected=new_collected) - - next_state = state._replace( - score=state.score + bonus, - flags=updated_flags, - flags_collected_mask=new_mask + player_speed = jax.lax.cond( + jnp.logical_and(state.player_car.speed < self.consts.MAX_SPEED, up), + lambda s: s + 1, + lambda s: s, + operand=player_speed, ) - return jax.lax.cond( - all_flags_collected, - lambda s: self._on_level_completed(s), + player_speed = jax.lax.cond( + jnp.logical_and(state.player_car.speed > -self.consts.MAX_SPEED, down), + lambda s: s - 1, lambda s: s, - next_state, + operand=player_speed, ) - @partial(jax.jit, static_argnums=(0,)) - def _extra_life_step(self, state: UpNDownState) -> UpNDownState: - """Award extra life every 10000 points.""" - next_milestone = state.last_extra_life_score + self.consts.EXTRA_LIFE_THRESHOLD - should_award = state.score >= next_milestone - - new_lives = jnp.where(should_award, state.lives + 1, state.lives) - new_last_score = jnp.where(should_award, next_milestone, state.last_extra_life_score) - - return state._replace(lives=new_lives, last_extra_life_score=new_last_score) - @partial(jax.jit, static_argnums=(0,)) - def _collectible_step_main(self, state: UpNDownState) -> UpNDownState: - """Update collectible spawning, despawning, and collection.""" - new_player_y = state.player_car.position.y - player_x = state.player_car.position.x - current_road = state.player_car.current_road - - updated_collectibles, collectible_score, new_collectible_timer, rng_key = self._collectible_step( - state, new_player_y, player_x, current_road, state.rng_key - ) - - return state._replace( - score=state.score + collectible_score, - collectibles=updated_collectibles, - collectible_spawn_timer=new_collectible_timer, - rng_key=rng_key, + is_jumping = jnp.logical_or(jnp.logical_and(state.is_jumping, state.jump_cooldown > 0), jnp.logical_and(state.is_on_road, jnp.logical_and(player_speed > 0, state.jump_cooldown == 0))) + jump_cooldown = jax.lax.cond( + state.jump_cooldown > 0, + lambda s: s - 1, + lambda s: jax.lax.cond(is_jumping, + lambda _: self.consts.JUMP_FRAMES, + lambda _: 0, + operand=None), + operand=state.jump_cooldown, ) - @partial(jax.jit, static_argnums=(0,)) - def _initialize_collectibles(self) -> Collectible: - """Return a cleared collectible pool.""" - return Collectible( - y=jnp.zeros(self.consts.MAX_COLLECTIBLES), - x=jnp.zeros(self.consts.MAX_COLLECTIBLES), - road=jnp.zeros(self.consts.MAX_COLLECTIBLES, dtype=jnp.int32), - color_idx=jnp.zeros(self.consts.MAX_COLLECTIBLES, dtype=jnp.int32), - type_id=jnp.zeros(self.consts.MAX_COLLECTIBLES, dtype=jnp.int32), - active=jnp.zeros(self.consts.MAX_COLLECTIBLES, dtype=jnp.bool_), - ) - @partial(jax.jit, static_argnums=(0,)) - def _initialize_enemies(self, key: chex.Array, player_start_y: chex.Array) -> EnemyCars: - """Seed the initial set of visible enemies around the player.""" - key_init, key_type, key_road, key_speed, key_sign = jax.random.split(key, 5) - - offsets = self.consts.INITIAL_ENEMY_BASE_OFFSET + self.consts.INITIAL_ENEMY_GAP * jnp.arange(self.consts.INITIAL_ENEMY_COUNT) - spawn_signs = jax.random.choice(key_sign, jnp.array([-1.0, 1.0]), shape=(self.consts.INITIAL_ENEMY_COUNT,)) - raw_spawn_y = player_start_y + spawn_signs * offsets - init_y = -(((raw_spawn_y) * -1) % self.consts.TRACK_LENGTH) - init_road = jax.random.randint(key_road, shape=(self.consts.INITIAL_ENEMY_COUNT,), minval=0, maxval=2) - - init_segments = jax.vmap(self._get_road_segment)(init_y) - - init_x = jax.vmap(lambda y, seg, road: jax.lax.cond( - road == 0, - lambda _: self._get_x_on_road(y, seg, self.consts.FIRST_TRACK_CORNERS_X), - lambda _: self._get_x_on_road(y, seg, self.consts.SECOND_TRACK_CORNERS_X), - operand=None, - ))(init_y, init_segments, init_road) - - init_type = jax.random.randint(key_type, shape=(self.consts.INITIAL_ENEMY_COUNT,), minval=0, maxval=4) - init_speed_mag = jax.random.randint(key_speed, shape=(self.consts.INITIAL_ENEMY_COUNT,), minval=self.consts.ENEMY_SPEED_MIN, maxval=self.consts.ENEMY_SPEED_MAX + 1) - init_speed_sign = jax.random.choice(key_init, jnp.array([-1, 1]), shape=(self.consts.INITIAL_ENEMY_COUNT,)) - init_speed = init_speed_mag * init_speed_sign - - def init_direction(seg, road): - raw = jax.lax.cond( - road == 0, - lambda _: self.consts.FIRST_TRACK_CORNERS_X[seg+1] - self.consts.FIRST_TRACK_CORNERS_X[seg], - lambda _: self.consts.SECOND_TRACK_CORNERS_X[seg+1] - self.consts.SECOND_TRACK_CORNERS_X[seg], - operand=None, - ) - return jax.lax.cond(raw > 0, lambda _: 1, lambda _: -1, operand=None) - init_dir = jax.vmap(init_direction)(init_segments, init_road) - pad = self.consts.MAX_ENEMY_CARS - self.consts.INITIAL_ENEMY_COUNT + ##check if player is on the the road + is_on_road = ~state.is_jumping - return EnemyCars( - position=EntityPosition( - x=jnp.concatenate([init_x, jnp.zeros(pad, dtype=jnp.float32)]), - y=jnp.concatenate([init_y, jnp.zeros(pad, dtype=jnp.float32)]), - width=jnp.full((self.consts.MAX_ENEMY_CARS,), self.consts.PLAYER_SIZE[0]), - height=jnp.full((self.consts.MAX_ENEMY_CARS,), self.consts.PLAYER_SIZE[1]), - ), - speed=jnp.concatenate([init_speed, jnp.zeros(pad, dtype=jnp.int32)]), - type=jnp.concatenate([init_type, jnp.zeros(pad, dtype=jnp.int32)]), - current_road=jnp.concatenate([init_road, jnp.zeros(pad, dtype=jnp.int32)]), - road_index_A=jnp.concatenate([init_segments, jnp.zeros(pad, dtype=jnp.int32)]), - road_index_B=jnp.concatenate([init_segments, jnp.zeros(pad, dtype=jnp.int32)]), - direction_x=jnp.concatenate([init_dir, jnp.zeros(pad, dtype=jnp.int32)]), - active=jnp.concatenate([jnp.ones(self.consts.INITIAL_ENEMY_COUNT, dtype=jnp.bool_), jnp.zeros(pad, dtype=jnp.bool_)]), - age=jnp.concatenate([jnp.zeros(self.consts.INITIAL_ENEMY_COUNT, dtype=jnp.int32), jnp.zeros(pad, dtype=jnp.int32)]), - ) + road_index_A, road_index_B = self._car_past_corner(state.player_car, state) - @partial(jax.jit, static_argnums=(0,)) - def _enemy_step_main(self, state: UpNDownState) -> UpNDownState: - """Spawn and move enemy cars with adaptive spawning for consistent enemy presence.""" - # Split RNG keys - use more splits to ensure better randomization - rng_key, key_spawn_offset, key_spawn_side, key_spawn_speed, key_spawn_direction, key_spawn_type, key_spawn_sign, key_flip_root, key_extra = jax.random.split(state.rng_key, 9) - - # Further split key_spawn_type to get more entropy for type selection - key_spawn_type = jax.random.fold_in(key_spawn_type, state.step_counter) - - active_mask = state.enemy_cars.active - active_count = jnp.sum(active_mask.astype(jnp.int32)) - can_spawn = active_count < self.consts.MAX_ENEMY_CARS - - # Calculate how many enemies are "visible" (within visible distance of player) - player_y = state.player_car.position.y - enemy_distances = jnp.abs(state.enemy_cars.position.y - player_y) - wrapped_distances = jnp.minimum(enemy_distances, self.consts.TRACK_LENGTH - enemy_distances) - visible_mask = jnp.logical_and(active_mask, wrapped_distances < self.consts.ENEMY_VISIBLE_DISTANCE) - visible_count = jnp.sum(visible_mask.astype(jnp.int32)) - - # Adaptive spawn interval: spawn faster when fewer visible enemies - # If below minimum, spawn immediately (interval = 0) - # Otherwise scale between BASE and MAX based on visible count - needs_urgent_spawn = visible_count < self.consts.ENEMY_MIN_VISIBLE_COUNT - spawn_interval = jnp.where( - needs_urgent_spawn, - jnp.int32(0), # Spawn immediately when too few visible - jnp.int32(self.consts.ENEMY_SPAWN_INTERVAL_BASE + - (visible_count * (self.consts.ENEMY_SPAWN_INTERVAL_MAX - self.consts.ENEMY_SPAWN_INTERVAL_BASE)) // - self.consts.MAX_ENEMY_CARS) + direction_change = jax.lax.cond( + jnp.logical_and(is_on_road, jnp.logical_or(jnp.logical_and(jnp.equal(road_index_A, state.player_car.road_index_A) , state.player_car.current_road == 0), (jnp.logical_and(jnp.equal(road_index_B, state.player_car.road_index_B) , state.player_car.current_road == 1)))) , + lambda s: False, + lambda s: True, + operand=None, ) - # Spawn when timer expires OR when we urgently need more enemies - timer_expired = state.enemy_spawn_timer <= 0 - should_spawn = jnp.logical_and( - jnp.logical_or(timer_expired, needs_urgent_spawn), - can_spawn - ) - - # Reset timer with adaptive interval - spawn_timer = jnp.where( - should_spawn, - spawn_interval, - jnp.maximum(state.enemy_spawn_timer - 1, 0), - ) - inactive_mask = jnp.logical_not(active_mask) - first_inactive = jnp.argmax(inactive_mask.astype(jnp.int32)) - has_inactive = jnp.any(inactive_mask) - spawn_idx = jnp.where(has_inactive, first_inactive, jnp.array(0, dtype=jnp.int32)) - spawn_mask = (jnp.arange(self.consts.MAX_ENEMY_CARS) == spawn_idx) & should_spawn & has_inactive - - # Spawn closer when urgent (fewer visible enemies), farther when plenty exist - base_offset = jnp.where( - needs_urgent_spawn, - self.consts.ENEMY_SPAWN_OFFSET_MIN, # Spawn closer when needed - self.consts.ENEMY_SPAWN_OFFSET_MIN + visible_count * 10.0 # Farther when plenty exist + car_direction_x = jax.lax.cond( + direction_change, + lambda s: jax.lax.cond(state.player_car.current_road == 0, + lambda s: self.consts.FIRST_TRACK_CORNERS_X[road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_X[road_index_A], + lambda s: self.consts.SECOND_TRACK_CORNERS_X[road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[road_index_B], + operand=None), + lambda s: s, + operand=state.player_car.direction_x, ) - spawn_offset = base_offset + jax.random.uniform(key_spawn_offset, minval=0.0, maxval=30.0) - spawn_side = jax.random.choice(key_spawn_side, jnp.array([-1.0, 1.0])) - raw_spawn_y = state.player_car.position.y + spawn_side * spawn_offset - spawn_y = -(((raw_spawn_y) * -1) % self.consts.TRACK_LENGTH) - spawn_road = self._sample_enemy_spawn_road(key_spawn_direction) - - segment_spawn = self._get_road_segment(spawn_y) - spawn_x = jnp.where( - spawn_road == 0, - self._get_x_on_road(spawn_y, segment_spawn, self.consts.FIRST_TRACK_CORNERS_X), - self._get_x_on_road(spawn_y, segment_spawn, self.consts.SECOND_TRACK_CORNERS_X), - ) + is_landing = jnp.logical_and(state.jump_cooldown == 1, jump_cooldown == 0) - spawn_speed_mag = jax.random.randint(key_spawn_speed, shape=(), minval=self.consts.ENEMY_SPEED_MIN, maxval=self.consts.ENEMY_SPEED_MAX + 1) - spawn_speed_sign = jax.random.choice(key_spawn_sign, jnp.array([-1, 1])) - spawn_speed = spawn_speed_mag * spawn_speed_sign - spawn_type = jax.random.randint(key_spawn_type, shape=(), minval=0, maxval=4) + ##calculate new position with speed (TODO: calculate better speed) + player_y = state.player_car.position.y + player_speed + player_x = state.player_car.position.x + player_speed * car_direction_x - direction_raw = jnp.where( - spawn_road == 0, - self.consts.FIRST_TRACK_CORNERS_X[segment_spawn+1] - self.consts.FIRST_TRACK_CORNERS_X[segment_spawn], - self.consts.SECOND_TRACK_CORNERS_X[segment_spawn+1] - self.consts.SECOND_TRACK_CORNERS_X[segment_spawn], - ) - spawn_direction_x = jnp.where(direction_raw > 0, 1, -1) - - enemy_position_x = jnp.where(spawn_mask, spawn_x, state.enemy_cars.position.x) - enemy_position_y = jnp.where(spawn_mask, spawn_y, state.enemy_cars.position.y) - enemy_width = state.enemy_cars.position.width - enemy_height = state.enemy_cars.position.height - enemy_speed = jnp.where(spawn_mask, spawn_speed, state.enemy_cars.speed) - enemy_type = jnp.where(spawn_mask, spawn_type, state.enemy_cars.type) - enemy_current_road = jnp.where(spawn_mask, spawn_road, state.enemy_cars.current_road) - enemy_road_index_A = jnp.where(spawn_mask, segment_spawn, state.enemy_cars.road_index_A) - enemy_road_index_B = jnp.where(spawn_mask, segment_spawn, state.enemy_cars.road_index_B) - enemy_direction_x = jnp.where(spawn_mask, spawn_direction_x, state.enemy_cars.direction_x) - enemy_active = jnp.where(spawn_mask, True, state.enemy_cars.active) - enemy_age = jnp.where(spawn_mask, jnp.zeros_like(state.enemy_cars.age), state.enemy_cars.age) - - flip_keys = jax.random.split(key_flip_root, self.consts.MAX_ENEMY_CARS) - flip_mask = jax.vmap(lambda k: jax.random.uniform(k) < self.consts.ENEMY_DIRECTION_SWITCH_PROB)(flip_keys) - enemy_speed = jnp.where(jnp.logical_and(enemy_active, flip_mask), -enemy_speed, enemy_speed) - - move_fn = lambda px, py, ra, rb, cr, sp, tp: self._advance_car_core( - position_x=px, - position_y=py, - road_index_A=ra, - road_index_B=rb, - current_road=cr, - speed=sp, - step_counter=state.step_counter, - width=self.consts.PLAYER_SIZE[0], - height=self.consts.PLAYER_SIZE[1], - car_type=tp, - ) - - advanced_cars = jax.vmap(move_fn)( - enemy_position_x, - enemy_position_y, - enemy_road_index_A, - enemy_road_index_B, - enemy_current_road, - enemy_speed, - enemy_type, - ) + landing_in_Water, between_roads, road_A_x, road_B_x = self._landing_in_water(state, player_x, player_y) + landing_in_Water = jnp.logical_and(is_landing, landing_in_Water) + - moved_position_x = jnp.where(enemy_active, advanced_cars.position.x, enemy_position_x) - moved_position_y = jnp.where(enemy_active, advanced_cars.position.y, enemy_position_y) - moved_road_index_A = jnp.where(enemy_active, advanced_cars.road_index_A, enemy_road_index_A) - moved_road_index_B = jnp.where(enemy_active, advanced_cars.road_index_B, enemy_road_index_B) - moved_current_road = jnp.where(enemy_active, advanced_cars.current_road, enemy_current_road) - moved_direction_x = jnp.where(enemy_active, advanced_cars.direction_x, enemy_direction_x) - - enemy_age = jnp.where(enemy_active, enemy_age + 1, enemy_age) - - delta_y = moved_position_y - state.player_car.position.y - wrapped_dist = jnp.minimum(jnp.abs(delta_y), self.consts.TRACK_LENGTH - jnp.abs(delta_y)) - far_mask = wrapped_dist > self.consts.ENEMY_DESPAWN_DISTANCE - age_mask = enemy_age > self.consts.ENEMY_MAX_AGE - despawn_mask = jnp.logical_and(enemy_active, jnp.logical_or(far_mask, age_mask)) - final_active = jnp.logical_and(enemy_active, jnp.logical_not(despawn_mask)) - enemy_age = jnp.where(despawn_mask, jnp.zeros_like(enemy_age), enemy_age) - - next_enemy_cars = EnemyCars( - position=EntityPosition( - x=moved_position_x, - y=moved_position_y, - width=enemy_width, - height=enemy_height, + current_road = jax.lax.cond( + landing_in_Water, + lambda s: 2, + lambda s: jax.lax.cond( + is_on_road, + lambda s: state.player_car.current_road, + lambda s: jax.lax.cond( + jnp.abs(player_x - road_A_x) < jnp.abs(player_x - road_B_x), + lambda s: 0, + lambda s: 1, + operand=None, + ), + operand=None, ), - speed=enemy_speed, - type=enemy_type, - current_road=moved_current_road, - road_index_A=moved_road_index_A, - road_index_B=moved_road_index_B, - direction_x=moved_direction_x, - active=final_active, - age=enemy_age, - ) - - spawn_timer = self._adjust_enemy_spawn_timer(state, spawn_timer) - - return state._replace( - enemy_cars=next_enemy_cars, - enemy_spawn_timer=spawn_timer, - rng_key=rng_key, - ) - - @partial(jax.jit, static_argnums=(0,)) - def _respawn_after_collision(self, state: UpNDownState, new_lives: chex.Array) -> UpNDownState: - """Respawn the player on a random road while preserving score and flags.""" - rng_key, road_key, enemy_key = jax.random.split(state.rng_key, 3) - - player_start_y = jnp.array(0.0) - start_segment = jnp.array(0, dtype=jnp.int32) - respawn_road = jax.random.randint(road_key, shape=(), minval=0, maxval=2) - - start_x = jax.lax.cond( - respawn_road == 0, - lambda _: self._get_x_on_road(player_start_y, start_segment, self.consts.FIRST_TRACK_CORNERS_X), - lambda _: self._get_x_on_road(player_start_y, start_segment, self.consts.SECOND_TRACK_CORNERS_X), operand=None, ) - - enemy_cars = self._initialize_enemies(enemy_key, player_start_y) - collectibles = self._initialize_collectibles() - - player_car = Car( - position=EntityPosition( - x=jnp.asarray(start_x, dtype=jnp.float32), - y=jnp.asarray(player_start_y, dtype=jnp.float32), - width=self.consts.PLAYER_SIZE[0], - height=self.consts.PLAYER_SIZE[1], - ), - speed=jnp.array(0, dtype=jnp.int32), - direction_x=jnp.array(0, dtype=jnp.int32), - current_road=respawn_road, - road_index_A=start_segment, - road_index_B=start_segment, - type=jnp.array(0, dtype=jnp.int32), - ) - return UpNDownState( score=state.score, - lives=new_lives, - is_dead=jnp.array(False), - respawn_timer=jnp.array(0, dtype=jnp.int32), difficulty=state.difficulty, - jump_cooldown=jnp.array(0, dtype=jnp.int32), - post_jump_cooldown=jnp.array(0, dtype=jnp.int32), - is_jumping=jnp.array(False), - is_on_road=jnp.array(True), - player_car=player_car, - step_counter=state.step_counter, - round_started=jnp.array(False), - movement_steps=jnp.array(0), - steep_road_timer=jnp.array(0, dtype=jnp.int32), - jump_slope=jnp.array(0.0, dtype=jnp.float32), - flags=state.flags, - flags_collected_mask=state.flags_collected_mask, - collectibles=collectibles, - collectible_spawn_timer=jnp.array(0, dtype=jnp.int32), - enemy_cars=enemy_cars, - enemy_spawn_timer=jnp.array(self.consts.ENEMY_SPAWN_INTERVAL_BASE, dtype=jnp.int32), - awaiting_respawn=jnp.array(False), - awaiting_round_start=jnp.array(True), # Wait for input to start round after respawn - input_released=jnp.array(False), # Require button release before round can start - jump_key_released=jnp.array(True), - last_extra_life_score=state.last_extra_life_score, - jump_total_duration=jnp.array(self.consts.JUMP_FRAMES, dtype=jnp.int32), - rng_key=rng_key, - ) - - @partial(jax.jit, static_argnums=(0,)) - def _enemy_collision_step_main(self, state: UpNDownState) -> UpNDownState: - """Handle collisions between the player and enemy cars. - - - While airborne, collisions are ignored except for the final jump frames, - where hitting an enemy despawns it and awards a bonus. - - On ground collisions, the player loses a life and the stage soft-resets - without clearing score or collected flags. - - Landing collisions use a larger distance and are road-independent (for crossings). - """ - player_x = state.player_car.position.x - player_y = state.player_car.position.y - - dx = jnp.abs(state.enemy_cars.position.x - player_x) - dy = jnp.abs(state.enemy_cars.position.y - player_y) - wrapped_dy = jnp.minimum(dy, self.consts.TRACK_LENGTH - dy) - - # For ground collision: only trigger when enemy position is within tight distance - overlap_x_ground = dx <= self.consts.GROUND_COLLISION_DISTANCE - overlap_y_ground = wrapped_dy <= self.consts.GROUND_COLLISION_DISTANCE - # For late jump collision: use larger overlap based on car dimensions plus extra tolerance - # "slightly more forgiving" - jump_tolerance = 4.0 - overlap_x_jump = dx <= (state.player_car.position.width + state.enemy_cars.position.width) / 2.0 + jump_tolerance - overlap_y_jump = wrapped_dy <= (state.player_car.position.height + state.enemy_cars.position.height) / 2.0 + jump_tolerance - same_road = state.enemy_cars.current_road == state.player_car.current_road - - # Ground collision mask uses tight 3-pixel distance and same road - ground_collision_mask = jnp.logical_and(state.enemy_cars.active, jnp.logical_and(same_road, jnp.logical_and(overlap_x_ground, overlap_y_ground))) - # Jump collision mask is road-independent - can destroy enemies on either road when jumping - jump_collision_mask = jnp.logical_and(state.enemy_cars.active, jnp.logical_and(overlap_x_jump, overlap_y_jump)) - collision_mask = jump_collision_mask # For late jump scoring - - any_jump_collision = jnp.any(jump_collision_mask) - any_ground_collision = jnp.any(ground_collision_mask) - - # Check if player is in post-landing invincibility phase - is_invincible = state.post_jump_cooldown > 0 - - late_jump_window = jnp.logical_and(state.is_jumping, state.jump_cooldown <= self.consts.LATE_JUMP_COLLISION_FRAMES) - late_jump_collision = jnp.logical_and(any_jump_collision, late_jump_window) - # Ground collision only applies when not jumping AND not in post-landing invincibility - grounded_collision = jnp.logical_and( - any_ground_collision, - jnp.logical_and(jnp.logical_not(state.is_jumping), jnp.logical_not(is_invincible)) - ) - - def handle_late_jump(): - hits = collision_mask.astype(jnp.int32) - bonus = jnp.sum(hits) * self.consts.LATE_JUMP_ENEMY_SCORE - new_enemy_active = jnp.logical_and(state.enemy_cars.active, jnp.logical_not(collision_mask)) - new_enemy_age = jnp.where(collision_mask, jnp.zeros_like(state.enemy_cars.age), state.enemy_cars.age) - new_enemy_cars = EnemyCars( - position=state.enemy_cars.position, - speed=state.enemy_cars.speed, - type=state.enemy_cars.type, - current_road=state.enemy_cars.current_road, - road_index_A=state.enemy_cars.road_index_A, - road_index_B=state.enemy_cars.road_index_B, - direction_x=state.enemy_cars.direction_x, - active=new_enemy_active, - age=new_enemy_age, - ) - - return state._replace(score=state.score + bonus, enemy_cars=new_enemy_cars) - - def handle_ground_collision(): - # Trigger death state - stop player but keep position (renderer hides player when awaiting_respawn) - dead_car = state.player_car._replace( - speed=jnp.array(0, dtype=jnp.int32), - ) - return state._replace( - lives=state.lives - 1, - is_dead=jnp.array(True), - awaiting_respawn=jnp.array(True), - player_car=dead_car, - ) - - # Ground collision causes death (landing is now protected by invincibility) - any_fatal_collision = grounded_collision - - return jax.lax.cond( - late_jump_collision, - lambda _: handle_late_jump(), - lambda _: jax.lax.cond( - any_fatal_collision, - lambda _: handle_ground_collision(), - lambda _: state, - operand=None, + jump_cooldown=jump_cooldown, + is_jumping=is_jumping, + is_on_road=is_on_road, + player_car=Car( + position=EntityPosition( + x=player_x, + y=player_y, + width=state.player_car.position.width, + height=state.player_car.position.height, + ), + speed=player_speed, + direction_x=car_direction_x, + current_road=current_road, + road_index_A=road_index_A, + road_index_B=road_index_B, + type=state.player_car.type, ), - operand=None, - ) - - @partial(jax.jit, static_argnums=(0,)) - def _passive_score_step_main(self, state: UpNDownState) -> UpNDownState: - """Award passive score at regular intervals after the player has started moving.""" - should_award = jnp.logical_and( - state.round_started, - state.movement_steps % self.consts.PASSIVE_SCORE_INTERVAL == 0, ) - bonus = jnp.where(should_award, jnp.int32(self.consts.PASSIVE_SCORE_AMOUNT), jnp.int32(0)) - return state._replace(score=state.score + bonus) - - @partial(jax.jit, static_argnums=(0,)) - def _reset_jit(self, key: chex.PRNGKey) -> Tuple[UpNDownObservation, UpNDownState]: - rng_key, flag_key, enemy_key = jax.random.split(key, 3) - - # Evenly spread flags along the track with small jitter - base_y = jnp.linspace(-900.0, -100.0, self.consts.NUM_FLAGS) - jitter = jax.random.uniform(flag_key, shape=(self.consts.NUM_FLAGS,), minval=-40.0, maxval=40.0) - flag_y_offsets = base_y + jitter - - # Alternate roads 0/1 for variety - flag_roads = jnp.arange(self.consts.NUM_FLAGS) % 2 - - # Calculate which road segment each flag is on based on Y position - flag_segments = jax.vmap(self._get_road_segment)(flag_y_offsets) - - # Each flag color index corresponds to its position (0-7) - flag_color_indices = jnp.arange(self.consts.NUM_FLAGS) - - flags = Flag( - y=flag_y_offsets, - road=flag_roads, - road_segment=flag_segments, - color_idx=flag_color_indices, - collected=jnp.zeros(self.consts.NUM_FLAGS, dtype=jnp.bool_), - ) - - # Initialize collectibles as all inactive (will spawn dynamically with mixed types) - collectibles = self._initialize_collectibles() - - # Seed initial visible enemies spaced around the player - player_start_y = jnp.array(0.0) - enemy_cars = self._initialize_enemies(enemy_key, player_start_y) + def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: state = UpNDownState( score=0, - lives=jnp.array(self.consts.INITIAL_LIVES, dtype=jnp.int32), - is_dead=jnp.array(False), - respawn_timer=jnp.array(0, dtype=jnp.int32), difficulty=self.consts.DIFFICULTIES[0], jump_cooldown=0, - post_jump_cooldown=0, is_jumping=False, is_on_road=True, player_car=Car( position=EntityPosition( - x=jnp.asarray(30, dtype=jnp.float32), - y=jnp.asarray(0, dtype=jnp.float32), + x=50, + y=50, width=self.consts.PLAYER_SIZE[0], height=self.consts.PLAYER_SIZE[1], ), @@ -1753,94 +282,15 @@ def _reset_jit(self, key: chex.PRNGKey) -> Tuple[UpNDownObservation, UpNDownStat road_index_B=0, type=0, ), - step_counter=jnp.array(0), - rng_key=rng_key, - round_started=jnp.array(False), - movement_steps=jnp.array(0), - steep_road_timer=jnp.array(0, dtype=jnp.int32), - jump_slope=jnp.array(0.0, dtype=jnp.float32), - flags=flags, - flags_collected_mask=jnp.zeros(self.consts.NUM_FLAGS, dtype=jnp.bool_), - collectibles=collectibles, - collectible_spawn_timer=jnp.array(0, dtype=jnp.int32), - enemy_cars=enemy_cars, - enemy_spawn_timer=jnp.array(self.consts.ENEMY_SPAWN_INTERVAL_BASE, dtype=jnp.int32), - awaiting_respawn=jnp.array(False), - awaiting_round_start=jnp.array(True), # Start frozen until first input - input_released=jnp.array(True), # Can start immediately at game start - jump_key_released=jnp.array(True), - last_extra_life_score=jnp.array(0, dtype=jnp.int32), - jump_total_duration=jnp.array(self.consts.JUMP_FRAMES, dtype=jnp.int32), ) initial_obs = self._get_observation(state) + return initial_obs, state - - @partial(jax.jit, static_argnums=(0,)) - def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: - if key is None: - key = jax.random.PRNGKey(42) - return self._reset_jit(key) @partial(jax.jit, static_argnums=(0,)) def step(self, state: UpNDownState, action: chex.Array) -> Tuple[UpNDownObservation, UpNDownState, float, bool, UpNDownInfo]: previous_state = state - - any_action = action != Action.NOOP - - # Track input release - set to True when no button is pressed - input_released = jnp.where(any_action, state.input_released, jnp.array(True)) - state = state._replace(input_released=input_released) - - # Check if we're awaiting respawn - if so, check for input to trigger respawn - should_respawn = jnp.logical_and(state.awaiting_respawn, any_action) - - # Respawn if player pressed any key while awaiting - state = jax.lax.cond( - should_respawn, - lambda s: self._respawn_after_collision(s, s.lives), # lives already decremented - lambda s: s, - state, - ) - - # Check if we're awaiting round start - if so, check for input to start round - # Only start if input was released since respawn (prevents holding button through) - should_start_round = jnp.logical_and( - jnp.logical_and(state.awaiting_round_start, any_action), - state.input_released # Must have released button first - ) - state = jax.lax.cond( - should_start_round, - lambda s: s._replace(awaiting_round_start=jnp.array(False)), - lambda s: s, - state, - ) - - # Skip all game logic if awaiting respawn OR awaiting round start - is_frozen = jnp.logical_or(state.awaiting_respawn, state.awaiting_round_start) - - def run_game_logic(s): - s = self._player_step(s, action) - s = self._death_step(s) - s = self._passive_score_step_main(s) - s = self._flag_step_main(s) - s = self._level_progression_step(s) - s = self._extra_life_step(s) - s = self._collectible_step_main(s) - s = self._enemy_step_main(s) - s = self._enemy_collision_step_main(s) - return s - - def freeze_game(s): - # Only increment step counter while frozen, everything else paused - return s._replace(step_counter=s.step_counter + 1) - - # Run game logic only if not frozen - state = jax.lax.cond( - is_frozen, - freeze_game, - run_game_logic, - state, - ) + state = self._player_step(state, action) done = self._get_done(state) env_reward = self._get_reward(previous_state, state) @@ -1851,191 +301,40 @@ def freeze_game(s): def render(self, state: UpNDownState) -> jnp.ndarray: - frame = self.renderer.render(state) - return jnp.asarray(frame, dtype=jnp.uint8) + return self.renderer.render(state) - @partial(jax.jit, static_argnums=(0,)) - def _get_observation(self, state: UpNDownState) -> UpNDownObservation: - """Build complete observation for RL agents. - - Reuses existing game classes directly. Extra fields are filtered during flatten. - """ - # Check if on steep road - is_on_steep_road = self._is_steep_road_segment( - state.player_car.current_road, - state.player_car.road_index_A, - state.player_car.road_index_B, + def _get_observation(self, state: UpNDownState): + player = EntityPosition( + x=jnp.array(state.player_car.position.x), + y=state.player_car.position.y, + width=jnp.array(self.consts.PLAYER_SIZE[0]), + height=jnp.array(self.consts.PLAYER_SIZE[1]), ) - return UpNDownObservation( - player_car=state.player_car, - enemy_cars=state.enemy_cars, - flags=state.flags, - collectibles=state.collectibles, - flags_collected_mask=state.flags_collected_mask.astype(jnp.int32), - player_score=jnp.int32(state.score), - lives=jnp.int32(state.lives), - is_jumping=jnp.int32(state.is_jumping), - jump_cooldown=jnp.int32(state.jump_cooldown), - is_on_steep_road=jnp.int32(is_on_steep_road), - round_started=jnp.int32(state.round_started), + player=player, ) - @partial(jax.jit, static_argnums=(0,)) - def flatten_car(self, car: Car) -> jnp.ndarray: - """Flatten a Car to a 1D array.""" - return jnp.concatenate([ - jnp.array([car.position.x], dtype=jnp.int32), - jnp.array([car.position.y], dtype=jnp.int32), - jnp.array([car.position.width], dtype=jnp.int32), - jnp.array([car.position.height], dtype=jnp.int32), - jnp.array([car.speed], dtype=jnp.int32), - jnp.array([car.type], dtype=jnp.int32), - jnp.array([car.current_road], dtype=jnp.int32), - jnp.array([car.road_index_A], dtype=jnp.int32), - jnp.array([car.road_index_B], dtype=jnp.int32), - jnp.array([car.direction_x], dtype=jnp.int32), - ]) - - @partial(jax.jit, static_argnums=(0,)) - def flatten_enemy_cars(self, enemy_cars: EnemyCars) -> jnp.ndarray: - """Flatten EnemyCars to a 1D array (all fields).""" - return jnp.concatenate([ - enemy_cars.position.x.astype(jnp.int32), - enemy_cars.position.y.astype(jnp.int32), - enemy_cars.position.width.astype(jnp.int32), - enemy_cars.position.height.astype(jnp.int32), - enemy_cars.speed.astype(jnp.int32), - enemy_cars.type.astype(jnp.int32), - enemy_cars.current_road.astype(jnp.int32), - enemy_cars.road_index_A.astype(jnp.int32), - enemy_cars.road_index_B.astype(jnp.int32), - enemy_cars.direction_x.astype(jnp.int32), - enemy_cars.active.astype(jnp.int32), - enemy_cars.age.astype(jnp.int32), - ]) - - @partial(jax.jit, static_argnums=(0,)) - def flatten_flags(self, flags: Flag) -> jnp.ndarray: - """Flatten Flag to a 1D array.""" - return jnp.concatenate([ - flags.y.astype(jnp.int32), - flags.road.astype(jnp.int32), - flags.road_segment.astype(jnp.int32), - flags.color_idx.astype(jnp.int32), - flags.collected.astype(jnp.int32), - ]) - - @partial(jax.jit, static_argnums=(0,)) - def flatten_collectibles(self, collectibles: Collectible) -> jnp.ndarray: - """Flatten Collectible to a 1D array (all fields).""" - return jnp.concatenate([ - collectibles.y.astype(jnp.int32), - collectibles.x.astype(jnp.int32), - collectibles.road.astype(jnp.int32), - collectibles.color_idx.astype(jnp.int32), - collectibles.type_id.astype(jnp.int32), - collectibles.active.astype(jnp.int32), - ]) - @partial(jax.jit, static_argnums=(0,)) def obs_to_flat_array(self, obs: UpNDownObservation) -> jnp.ndarray: - """Flatten the complete observation to a 1D array for RL. - - Order: - - Player car: 10 values (x, y, w, h, speed, type, road, road_index_A, road_index_B, direction_x) - - Enemy cars: MAX_ENEMY_CARS * 12 values (x, y, w, h, speed, type, road, road_index_A, road_index_B, direction_x, active, age) - - Flags: NUM_FLAGS * 5 values (y, road, segment, color, collected per flag) - - Collectibles: MAX_COLLECTIBLES * 6 values (y, x, road, color_idx, type, active per collectible) - - Flags collected mask: NUM_FLAGS values - - Score, lives, is_jumping, jump_cooldown, is_on_steep_road, round_started: 6 values - """ - return jnp.concatenate([ - self.flatten_car(obs.player_car), - self.flatten_enemy_cars(obs.enemy_cars), - self.flatten_flags(obs.flags), - self.flatten_collectibles(obs.collectibles), - obs.flags_collected_mask.flatten().astype(jnp.int32), - jnp.array([obs.player_score], dtype=jnp.int32), - jnp.array([obs.lives], dtype=jnp.int32), - jnp.array([obs.is_jumping], dtype=jnp.int32), - jnp.array([obs.jump_cooldown], dtype=jnp.int32), - jnp.array([obs.is_on_steep_road], dtype=jnp.int32), - jnp.array([obs.round_started], dtype=jnp.int32), - ]) + return jnp.concatenate([ + obs.player.x.flatten(), + obs.player.y.flatten(), + obs.player.height.flatten(), + obs.player.width.flatten(), + ] + ) def action_space(self) -> spaces.Discrete: return spaces.Discrete(6) - def observation_space(self) -> spaces.Dict: - """Returns the observation space for Up N Down. - - The observation reuses existing game classes: - - player_car: Car with position (x, y, w, h), speed, type, current_road, direction_x - - enemy_cars: EnemyCars with positions, speeds, types, roads, active flags - - flags: Flag with y, road, road_segment, color_idx, collected - - collectibles: Collectible with y, x, road, type_id, active - - flags_collected_mask: boolean array of shape (NUM_FLAGS,) - - player_score: int (0-999999) - - lives: int (0-5) - - is_jumping: int (0 or 1) - - jump_cooldown: int (0-48) - - is_on_steep_road: int (0 or 1) - - round_started: int (0 or 1) - """ + def observation_space(self) -> spaces: return spaces.Dict({ - "player_car": spaces.Dict({ - "position": spaces.Dict({ - "x": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), - "y": spaces.Box(low=-2000, high=0, shape=(), dtype=jnp.int32), - "width": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), - "height": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), - }), - "speed": spaces.Box(low=-self.consts.MAX_SPEED, high=self.consts.MAX_SPEED, shape=(), dtype=jnp.int32), - "type": spaces.Box(low=0, high=3, shape=(), dtype=jnp.int32), - "current_road": spaces.Box(low=0, high=2, shape=(), dtype=jnp.int32), - "road_index_A": spaces.Box(low=0, high=30, shape=(), dtype=jnp.int32), - "road_index_B": spaces.Box(low=0, high=30, shape=(), dtype=jnp.int32), - "direction_x": spaces.Box(low=-1, high=1, shape=(), dtype=jnp.int32), - }), - "enemy_cars": spaces.Dict({ - "position": spaces.Dict({ - "x": spaces.Box(low=0, high=160, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), - "y": spaces.Box(low=-2000, high=0, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), - "width": spaces.Box(low=0, high=160, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), - "height": spaces.Box(low=0, high=210, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), - }), - "speed": spaces.Box(low=-(self.consts.ENEMY_SPEED_MAX + 1), high=(self.consts.ENEMY_SPEED_MAX + 1), shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), - "type": spaces.Box(low=0, high=3, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), - "current_road": spaces.Box(low=0, high=2, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), - "road_index_A": spaces.Box(low=0, high=30, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), - "road_index_B": spaces.Box(low=0, high=30, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), - "direction_x": spaces.Box(low=-1, high=1, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), - "active": spaces.Box(low=0, high=1, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), - "age": spaces.Box(low=0, high=10000, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), - }), - "flags": spaces.Dict({ - "y": spaces.Box(low=-2000, high=0, shape=(self.consts.NUM_FLAGS,), dtype=jnp.int32), - "road": spaces.Box(low=0, high=1, shape=(self.consts.NUM_FLAGS,), dtype=jnp.int32), - "road_segment": spaces.Box(low=0, high=30, shape=(self.consts.NUM_FLAGS,), dtype=jnp.int32), - "color_idx": spaces.Box(low=0, high=7, shape=(self.consts.NUM_FLAGS,), dtype=jnp.int32), - "collected": spaces.Box(low=0, high=1, shape=(self.consts.NUM_FLAGS,), dtype=jnp.int32), + "player": spaces.Dict({ + "x": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), + "y": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), + "width": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), + "height": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), }), - "collectibles": spaces.Dict({ - "y": spaces.Box(low=-2000, high=0, shape=(self.consts.MAX_COLLECTIBLES,), dtype=jnp.int32), - "x": spaces.Box(low=0, high=160, shape=(self.consts.MAX_COLLECTIBLES,), dtype=jnp.int32), - "road": spaces.Box(low=0, high=1, shape=(self.consts.MAX_COLLECTIBLES,), dtype=jnp.int32), - "color_idx": spaces.Box(low=0, high=7, shape=(self.consts.MAX_COLLECTIBLES,), dtype=jnp.int32), - "type_id": spaces.Box(low=0, high=3, shape=(self.consts.MAX_COLLECTIBLES,), dtype=jnp.int32), - "active": spaces.Box(low=0, high=1, shape=(self.consts.MAX_COLLECTIBLES,), dtype=jnp.int32), - }), - "flags_collected_mask": spaces.Box(low=0, high=1, shape=(self.consts.NUM_FLAGS,), dtype=jnp.int32), - "player_score": spaces.Box(low=0, high=999999, shape=(), dtype=jnp.int32), - "lives": spaces.Box(low=0, high=5, shape=(), dtype=jnp.int32), - "is_jumping": spaces.Box(low=0, high=1, shape=(), dtype=jnp.int32), - "jump_cooldown": spaces.Box(low=0, high=48, shape=(), dtype=jnp.int32), - "is_on_steep_road": spaces.Box(low=0, high=1, shape=(), dtype=jnp.int32), - "round_started": spaces.Box(low=0, high=1, shape=(), dtype=jnp.int32), }) def image_space(self) -> spaces.Box: @@ -2047,34 +346,16 @@ def image_space(self) -> spaces.Box: ) @partial(jax.jit, static_argnums=(0,)) - def _get_info(self, state: UpNDownState) -> UpNDownInfo: - """Build info dict with additional debugging/analysis data.""" - # Get current road segment for player - road_index = jnp.where( - state.player_car.current_road == 0, - state.player_car.road_index_A, - state.player_car.road_index_B, - ) - - return UpNDownInfo( - step_counter=jnp.int32(state.step_counter), - difficulty=jnp.int32(state.difficulty), - movement_steps=jnp.int32(state.movement_steps), - jump_slope=jnp.float32(state.jump_slope), - player_road_segment=jnp.int32(road_index), - ) + def _get_info(self, state: UpNDownState, ) -> UpNDownInfo: + return UpNDownInfo(time=1) @partial(jax.jit, static_argnums=(0,)) def _get_reward(self, previous_state: UpNDownState, state: UpNDownState): - base_delta = jnp.asarray(state.score - previous_state.score, dtype=jnp.float32) - if self.reward_funcs: - extras = jnp.sum(jnp.array([fn(previous_state, state) for fn in self.reward_funcs], dtype=jnp.float32)) - return base_delta + extras - return base_delta + return state.score @partial(jax.jit, static_argnums=(0,)) def _get_done(self, state: UpNDownState) -> bool: - return state.lives <= 0 + return jnp.logical_not(True) class UpNDownRenderer(JAXGameRenderer): def __init__(self, consts: UpNDownConstants = None): @@ -2086,17 +367,12 @@ def __init__(self, consts: UpNDownConstants = None): #downscale=(84, 84) ) self.jr = render_utils.JaxRenderingUtils(self.config) - - background = self._createBackgroundSprite(self.config.game_dimensions) - top_block = self._createBackgroundSprite((25, self.config.game_dimensions[1])) - bottom_block = self._createBackgroundSprite((16, self.config.game_dimensions[1])) - temp_pointer = self._createBackgroundSprite((1, 1)) - blackout_square = self._createBackgroundSprite(self.consts.FLAG_BLACKOUT_SIZE) - # Build asset config locally (matches other games' pattern) - asset_config, road_files = self._get_asset_config(background, top_block, bottom_block, temp_pointer, blackout_square) + # 2. Update asset config to include both walls + asset_config = self._get_asset_config() sprite_path = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/" + # 3. Make a single call to the setup function ( self.PALETTE, self.SHAPE_MASKS, @@ -2104,423 +380,19 @@ def __init__(self, consts: UpNDownConstants = None): self.COLOR_TO_ID, self.FLIP_OFFSETS ) = self.jr.load_and_setup_assets(asset_config, sprite_path) - self.road_sizes, self.complete_road_size = self._get_road_sprite_sizes(road_files) - self.view_height = self.config.game_dimensions[0] - # Precompute offsets so repeated road tiles can wrap seamlessly without gaps. - road_cycle = max(1, self.complete_road_size) - repeats = max(1, int(-(-self.view_height // road_cycle)) + 2) # Ceiling division trick - self._road_tile_offsets = jnp.arange(-repeats, repeats + 1, dtype=jnp.int32) * jnp.int32(self.complete_road_size) - self._num_road_tiles = int(self._road_tile_offsets.shape[0]) - - self.enemy_sprite_names = { - self.consts.ENEMY_TYPE_CAMERO: "camero_left", - self.consts.ENEMY_TYPE_FLAG_CARRIER: "flag_carrier_left", - self.consts.ENEMY_TYPE_PICKUP: "pick_up_truck_left", - self.consts.ENEMY_TYPE_TRUCK: "truck_left", - } - - # Pre-pad enemy masks to a common shape so switch/array indexing works under jit - # Only use left sprites - right sprites are created by flipping horizontally - enemy_left_raw = [ - self.SHAPE_MASKS["camero_left"], - self.SHAPE_MASKS["flag_carrier_left"], - self.SHAPE_MASKS["pick_up_truck_left"], - self.SHAPE_MASKS["truck_left"], - ] - max_h = max([m.shape[0] for m in enemy_left_raw]) - max_w = max([m.shape[1] for m in enemy_left_raw]) - def _pad_mask(mask): - pad_h = max_h - mask.shape[0] - pad_w = max_w - mask.shape[1] - return jnp.pad(mask, ((0, pad_h), (0, pad_w)), constant_values=self.jr.TRANSPARENT_ID) - - self.enemy_left_masks = jnp.stack([_pad_mask(m) for m in enemy_left_raw], axis=0) - # Create right-facing masks by horizontally flipping the left masks - self.enemy_right_masks = jnp.flip(self.enemy_left_masks, axis=2) - - # Precompute flag mask data for recoloring without special-casing pink - self.flag_base_mask = self.SHAPE_MASKS["pink_flag"] - self.flag_solid_mask = self.flag_base_mask != self.jr.TRANSPARENT_ID - self.flag_palette_ids = self._compute_flag_palette_ids() - - # Precompute collectible mask data for recoloring (unified for all types) - # Reuse the same palette IDs since all collectibles use FLAG_COLORS - self.collectible_palette_ids = self.flag_palette_ids - - self.cherry_base_mask = self.SHAPE_MASKS["cherry"] - self.cherry_solid_mask = self.cherry_base_mask != self.jr.TRANSPARENT_ID - - self.balloon_base_mask = self.SHAPE_MASKS["balloon"] - self.balloon_solid_mask = self.balloon_base_mask != self.jr.TRANSPARENT_ID - - self.lollypop_base_mask = self.SHAPE_MASKS["lollypop"] - self.lollypop_solid_mask = self.lollypop_base_mask != self.jr.TRANSPARENT_ID - - self.ice_cream_base_mask = self.SHAPE_MASKS["ice_cream"] - self.ice_cream_solid_mask = self.ice_cream_base_mask != self.jr.TRANSPARENT_ID - - # Score rendering helpers - self.score_digit_masks = self.SHAPE_MASKS["score_digits"] - self.score_max_digits = 6 - self.score_digit_spacing = int(self.score_digit_masks.shape[2]) + 1 - self.score_render_y = 6 - self.score_center_x = self.config.game_dimensions[1] // 2 - self.config.game_dimensions[1] // 4 - - def _createBackgroundSprite(self, dimensions: Tuple[int, int]) -> jnp.ndarray: - """Creates a procedural background sprite for the game.""" - height, width = dimensions - color = (0, 0, 0, 255) # RGBA for wall color - shape = (height, width, 4) # Height, Width, RGBA channels - sprite = jnp.tile(jnp.array(color, dtype=jnp.uint8), (*shape[:2], 1)) - return sprite - - def _get_road_sprite_sizes(self, road_files: list[str]) -> list: - """Returns the sizes of the road sprites limited to the configured files.""" - road_dir = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/roads" - sizes = [] - for file in road_files: - sprite_name = os.path.basename(file) - sprite = jnp.load(f"{road_dir}/{sprite_name}") - sizes.append(sprite.shape[0]) - complete_size = int(sum(sizes)) - return sizes, complete_size - - def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.ndarray, bottomBlockSprite: jnp.ndarray, tempPointer: jnp.ndarray, blackoutSquare: jnp.ndarray) -> tuple[list, list[str]]: - """Return asset manifest and ordered road files (renderer-local like other games).""" - road_dir = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/roads" - road_files = sorted( - file for file in os.listdir(road_dir) - if file.endswith(".npy") - ) - roads = [f"roads/{file}" for file in road_files] + def _get_asset_config(self) -> list: + """Returns the declarative manifest of all assets for the game, including both wall sprites.""" return [ - {'name': 'background', 'type': 'background', 'data': backgroundSprite}, - {'name': 'road', 'type': 'group', 'files': roads}, + {'name': 'background', 'type': 'background', 'file': 'background/background1.npy'}, {'name': 'player', 'type': 'single', 'file': 'player_car.npy'}, - {'name': 'camero_left', 'type': 'single', 'file': 'enemy_cars/camero_left.npy'}, - {'name': 'flag_carrier_left', 'type': 'single', 'file': 'enemy_cars/flag_carrier_left.npy'}, - {'name': 'pick_up_truck_left', 'type': 'single', 'file': 'enemy_cars/pick_up_truck_left.npy'}, - {'name': 'truck_left', 'type': 'single', 'file': 'enemy_cars/truck_left.npy'}, - {'name': 'wall_top', 'type': 'procedural', 'data': topBlockSprite}, - {'name': 'wall_bottom', 'type': 'procedural', 'data': bottomBlockSprite}, - {'name': 'all_flags_top', 'type': 'single', 'file': 'all_flags_top.npy'}, - {'name': 'all_lives_bottom', 'type': 'single', 'file': 'all_lives_bottom.npy'}, - {'name': 'score_digits', 'type': 'digits', 'pattern': 'score/score_{}.npy'}, - {'name': 'pink_flag', 'type': 'single', 'file': 'pink_flag.npy'}, - {'name': 'flag_pole', 'type': 'single', 'file': 'flag_pole.npy'}, - {'name': 'cherry', 'type': 'single', 'file': 'cherry.npy'}, - {'name': 'balloon', 'type': 'single', 'file': 'balloon.npy'}, - {'name': 'lollypop', 'type': 'single', 'file': 'lollypop.npy'}, - {'name': 'ice_cream', 'type': 'single', 'file': 'ice_cream_cone.npy'}, - {'name': 'tempPointer', 'type': 'procedural', 'data': tempPointer}, - {'name': 'blackout_square', 'type': 'procedural', 'data': blackoutSquare}, - ], roads - - def _get_x_on_road(self, y: chex.Array, road_segment: chex.Array, track_corners_x: chex.Array) -> chex.Array: - """Calculate the X position on a road given a Y coordinate and road segment.""" - y1 = self.consts.TRACK_CORNERS_Y[road_segment] - y2 = self.consts.TRACK_CORNERS_Y[road_segment + 1] - x1 = track_corners_x[road_segment] - x2 = track_corners_x[road_segment + 1] - t = jnp.where(y2 != y1, (y - y1) / (y2 - y1), 0.0) - return x1 + t * (x2 - x1) - - def _find_palette_id(self, rgba: jnp.ndarray) -> int: - """Return palette index for an RGBA color, falling back to first entry if missing.""" - color_rgb = rgba[:3] - palette_rgb = self.PALETTE[:, :3] - matches = jnp.all(palette_rgb == color_rgb, axis=1) - found = jnp.argmax(matches) - # If no match, fallback to 0 (background) to avoid crashes - return int(found) - - def _compute_flag_palette_ids(self) -> jnp.ndarray: - """Precompute palette indices for each flag color without special-casing pink.""" - return jnp.array([self._find_palette_id(color) for color in self.consts.FLAG_COLORS], dtype=jnp.int32) - - @partial(jax.jit, static_argnums=(0,)) - def _jump_arc_offset(self, jump_cooldown: chex.Array, total_duration: chex.Array) -> chex.Array: - """Return a simple parabolic jump height based on remaining jump frames.""" - total = total_duration.astype(jnp.float32) - remaining = jnp.array(jump_cooldown, dtype=jnp.float32) - progress = jnp.clip((total - remaining) / jnp.maximum(total, 1.0), 0.0, 1.0) - centered = (progress - 0.5) * 2.0 - return self.consts.JUMP_ARC_HEIGHT * (1.0 - centered * centered) + ] @partial(jax.jit, static_argnums=(0,)) def render(self, state): raster = self.jr.create_object_raster(self.BACKGROUND) - road_diff = (-state.player_car.position.y) % self.complete_road_size - - # Vectorized road rendering: compute all Y offsets, stamp via vmap, fold overlays. - road_masks = self.SHAPE_MASKS["road"] # shape: (N, H, W) - num_segments = road_masks.shape[0] - - sizes = jnp.array(self.road_sizes, dtype=jnp.int32) - # Offsets: [0, cumsum(sizes[1:])] - offsets = jnp.concatenate([ - jnp.array([0], dtype=jnp.int32), - jnp.cumsum(sizes[1:], axis=0) - ], axis=0) - - base_y = jnp.asarray(self.consts.INITIAL_ROAD_POS_Y, dtype=jnp.int32) - y_positions = base_y + (road_diff.astype(jnp.int32)) - offsets - - tile_offsets = self._road_tile_offsets - tile_count = self._num_road_tiles - tiled_y = (y_positions[None, :] + tile_offsets[:, None]).reshape(-1) - tiled_masks = jnp.tile(road_masks, (tile_count, 1, 1)) - tiled_sizes = jnp.tile(sizes, tile_count) - - visible = jnp.logical_and( - tiled_y < self.view_height, - (tiled_y + tiled_sizes) > 0 - ) - - empty_raster = jnp.full_like(self.BACKGROUND, self.jr.TRANSPARENT_ID) - - def stamp(y, mask, is_visible): - return jax.lax.cond( - is_visible, - lambda _: self.jr.render_at_clipped(empty_raster, 10, y, mask), - lambda _: empty_raster, - operand=None, - ) - - overlays = jax.vmap(stamp)(tiled_y, tiled_masks, visible) - - total_segments = tile_count * num_segments - - def combine(i, acc): - over = overlays[i] - return jnp.where(over != self.jr.TRANSPARENT_ID, over, acc) - - raster = jax.lax.fori_loop(0, total_segments, combine, raster) - - def select_enemy_mask(enemy_type: chex.Array, going_left: chex.Array): - """Select enemy mask: left masks are base, right masks are horizontally flipped.""" - left_mask = self.enemy_left_masks[enemy_type] - right_mask = self.enemy_right_masks[enemy_type] - return jnp.where(going_left, left_mask, right_mask) - - # Pre-cast enemy properties to optimal types for rendering BEFORE the scan loop - enemy_active_arr = state.enemy_cars.active - enemy_x_arr = state.enemy_cars.position.x.astype(jnp.int32) - enemy_y_arr = state.enemy_cars.position.y - enemy_type_arr = state.enemy_cars.type - enemy_direction_x_arr = state.enemy_cars.direction_x - - def render_enemy(carry, enemy_idx): - raster = carry - enemy_active = enemy_active_arr[enemy_idx] - enemy_x = enemy_x_arr[enemy_idx] - enemy_y = enemy_y_arr[enemy_idx] - enemy_type = enemy_type_arr[enemy_idx] - direction_x = enemy_direction_x_arr[enemy_idx] - screen_y = 105 + (enemy_y - state.player_car.position.y) - # Hide enemies when awaiting round start or awaiting respawn - should_hide = jnp.logical_or(state.awaiting_round_start, state.awaiting_respawn) - is_visible = jnp.logical_and( - jnp.logical_and(enemy_active, jnp.logical_and(screen_y > 25, screen_y < 195)), - ~should_hide - ) - enemy_mask = select_enemy_mask(enemy_type, direction_x < 0) - - raster = jax.lax.cond( - is_visible, - lambda r: self.jr.render_at(r, enemy_x, screen_y.astype(jnp.int32), enemy_mask), - lambda r: r, - operand=raster, - ) - return raster, None - - raster_enemies, _ = jax.lax.scan(render_enemy, raster, jnp.arange(self.consts.MAX_ENEMY_CARS)) - - jump_offset = jax.lax.cond( - state.is_jumping, - lambda _: self._jump_arc_offset(state.jump_cooldown, state.jump_total_duration), - lambda _: jnp.array(0.0, dtype=jnp.float32), - operand=None, - ) - player_screen_y = jnp.int32(105 - jump_offset) player_mask = self.SHAPE_MASKS["player"] - # Skip rendering player when awaiting respawn OR awaiting round start - should_hide_player = jnp.logical_or(state.awaiting_respawn, state.awaiting_round_start) - raster_player = jax.lax.cond( - should_hide_player, - lambda _: raster_enemies, # Don't render player - lambda _: self.jr.render_at_clipped(raster_enemies, state.player_car.position.x, player_screen_y, player_mask), - operand=None, - ) - - wall_top_mask = self.SHAPE_MASKS["wall_top"] - raster_wall_top = self.jr.render_at(raster_player, 0, 0, wall_top_mask) - - wall_bottom_mask = self.SHAPE_MASKS["wall_bottom"] - raster_wall_bottom = self.jr.render_at(raster_wall_top, 0, 210 - wall_bottom_mask.shape[0], wall_bottom_mask) - - all_flags_top_mask = self.SHAPE_MASKS["all_flags_top"] - raster_flags_top = self.jr.render_at(raster_wall_bottom, 10, 20, all_flags_top_mask) - - # Render score centered at the top using dedicated score digit sprites - score_digits = self.jr.int_to_digits(state.score, max_digits=self.score_max_digits) - non_zero_mask = score_digits != 0 - has_non_zero = jnp.any(non_zero_mask) - first_non_zero = jnp.argmax(non_zero_mask) - start_index = jax.lax.select(has_non_zero, first_non_zero, self.score_max_digits - 1) - num_to_render = jax.lax.select(has_non_zero, self.score_max_digits - start_index, 1) - - total_width = num_to_render * self.score_digit_spacing - score_x = self.score_center_x - (total_width // 2) - - raster_score = self.jr.render_label_selective( - raster_flags_top, - jnp.int32(score_x), - self.score_render_y, - score_digits, - self.score_digit_masks, - start_index, - num_to_render, - spacing=self.score_digit_spacing, - max_digits_to_render=self.score_max_digits, - ) - - # Render flags on the road - flag_pole_mask = self.SHAPE_MASKS["flag_pole"] - - def render_flag(carry, flag_idx): - raster = carry - flag_y = state.flags.y[flag_idx] - flag_road = state.flags.road[flag_idx] - flag_segment = state.flags.road_segment[flag_idx] - flag_collected = state.flags.collected[flag_idx] - flag_color_idx = state.flags.color_idx[flag_idx] - - flag_x = jax.lax.cond( - flag_road == 0, - lambda _: self._get_x_on_road(flag_y, flag_segment, self.consts.FIRST_TRACK_CORNERS_X), - lambda _: self._get_x_on_road(flag_y, flag_segment, self.consts.SECOND_TRACK_CORNERS_X), - operand=None, - ) - screen_y = 105 + (flag_y - state.player_car.position.y) - # Hide flags when awaiting round start or awaiting respawn - should_hide = jnp.logical_or(state.awaiting_round_start, state.awaiting_respawn) - is_visible = jnp.logical_and( - jnp.logical_and(screen_y > 25, screen_y < 195), - jnp.logical_and(~flag_collected, ~should_hide) - ) - color_id = self.flag_palette_ids[flag_color_idx] - colored_flag_mask = jnp.where( - self.flag_solid_mask, - color_id, - self.flag_base_mask, - ) - raster = jax.lax.cond( - is_visible, - lambda r: self.jr.render_at( - self.jr.render_at(r, flag_x.astype(jnp.int32), screen_y.astype(jnp.int32), colored_flag_mask), - (flag_x + 5).astype(jnp.int32), screen_y.astype(jnp.int32), flag_pole_mask - ), - lambda r: r, - operand=raster, - ) - return raster, None - - raster_flags, _ = jax.lax.scan(render_flag, raster_score, jnp.arange(self.consts.NUM_FLAGS)) - - blackout_mask = self.SHAPE_MASKS["blackout_square"] - - def render_blackout(carry, flag_idx): - raster = carry - flag_collected = state.flags_collected_mask[flag_idx] - blackout_x = self.consts.FLAG_TOP_X_POSITIONS[flag_idx] - blackout_y = self.consts.FLAG_TOP_Y - raster = jax.lax.cond( - flag_collected, - lambda r: self.jr.render_at(r, blackout_x, blackout_y, blackout_mask), - lambda r: r, - operand=raster, - ) - return raster, None - - raster_blackout, _ = jax.lax.scan(render_blackout, raster_flags, jnp.arange(self.consts.NUM_FLAGS)) - - def render_collectible(carry, collectible_idx): - raster = carry - collectible_y = state.collectibles.y[collectible_idx] - collectible_x = state.collectibles.x[collectible_idx] - collectible_active = state.collectibles.active[collectible_idx] - collectible_color_idx = state.collectibles.color_idx[collectible_idx] - collectible_type_id = state.collectibles.type_id[collectible_idx] - screen_y = 105 + (collectible_y - state.player_car.position.y) - # Hide collectibles when awaiting round start or awaiting respawn - should_hide = jnp.logical_or(state.awaiting_round_start, state.awaiting_respawn) - is_visible = jnp.logical_and( - jnp.logical_and(screen_y > 25, screen_y < 195), - jnp.logical_and(collectible_active, ~should_hide) - ) - - def get_sprite_and_mask(type_id): - # Use switch for O(1) lookup instead of nested conditionals - def get_cherry(_): - return (self.cherry_base_mask, self.cherry_solid_mask, self.collectible_palette_ids) - def get_balloon(_): - return (self.balloon_base_mask, self.balloon_solid_mask, self.collectible_palette_ids) - def get_lollypop(_): - return (self.lollypop_base_mask, self.lollypop_solid_mask, self.collectible_palette_ids) - def get_ice_cream(_): - return (self.ice_cream_base_mask, self.ice_cream_solid_mask, self.collectible_palette_ids) - - return jax.lax.switch( - type_id, - [get_cherry, get_balloon, get_lollypop, get_ice_cream], - None, - ) - - base_mask, solid_mask, palette_ids = get_sprite_and_mask(collectible_type_id) - color_id = palette_ids[collectible_color_idx] - colored_mask = jnp.where( - (base_mask != self.jr.TRANSPARENT_ID) & (base_mask != 0), - color_id, - - base_mask, - ) - raster = jax.lax.cond( - is_visible, - lambda r: self.jr.render_at(r, collectible_x.astype(jnp.int32), screen_y.astype(jnp.int32), colored_mask), - lambda r: r, - operand=raster, - ) - return raster, None - - raster_collectibles, _ = jax.lax.scan(render_collectible, raster_blackout, jnp.arange(self.consts.MAX_COLLECTIBLES)) - - all_lives_bottom_mask = self.SHAPE_MASKS["all_lives_bottom"] - raster_lives = self.jr.render_at(raster_collectibles, 10, 195, all_lives_bottom_mask) - - # Black out lost lives (similar to flag blackout) - blackout_mask = self.SHAPE_MASKS["blackout_square"] - lives_lost = self.consts.INITIAL_LIVES - state.lives - - def render_life_blackout(carry, life_idx): - raster = carry - # Black out this life if it has been lost (life_idx < lives_lost) - should_blackout = life_idx < lives_lost - blackout_x = self.consts.LIFE_BOTTOM_X_POSITIONS[life_idx] - blackout_y = self.consts.LIFE_BOTTOM_Y - raster = jax.lax.cond( - should_blackout, - lambda r: self.jr.render_at(r, blackout_x, blackout_y, blackout_mask), - lambda r: r, - operand=raster, - ) - return raster, None - - raster_lives_blackout, _ = jax.lax.scan(render_life_blackout, raster_lives, jnp.arange(self.consts.INITIAL_LIVES)) - - wall_bottom_mask = self.SHAPE_MASKS["tempPointer"] - raster_pointer = self.jr.render_at(raster_lives_blackout, 140, 25, wall_bottom_mask) + raster = self.jr.render_at(raster, state.player_car.position.x, state.player_car.position.y, player_mask) - return self.jr.render_from_palette(raster_pointer, self.PALETTE) \ No newline at end of file + return self.jr.render_from_palette(raster, self.PALETTE) \ No newline at end of file diff --git a/src/jaxatari/games/sprites/up_n_down/backround/backround1.npy b/src/jaxatari/games/sprites/up_n_down/backround/backround1.npy deleted file mode 100644 index 6c353b610ae66a21a8991791f5875675ae42bf48..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 102944 zcmeI4O^cjG6otpFn{2b3-6%{JK_lW1aHoji!c_<+VnB>c)Qum1!XG@Qk={H_IDO_; z)m!!U5yEih+`4t|t*7hl^o;rU>+in$=KBw>{Q zbNAcbi(j6BXh z;V>4)S~E5qQ|#{{&r0T)y+_u&H$D6O!`S-o$HJKY+1cLimC^l#u`t$}x!IXwzb&4X z%(>bOSpV#LwfTmz_1=$#u_xuF(M(TC-P`}mZaIDwW)jCG+ zceuw?&7*5__hY>3XKSDJy1H8b=>5W&{#jV{Ou61YW5zR|tC>$}H$UUUSQraq{`)b` zjIXQFn)SH4cfG8q^?v8vW8AZMO=>^xuQc6Q z*8XQnWX9E7vCwPZFc!vov31P5B4fwek6+z;g7L>Z5Bph3DYE9Mk zswX!;<5hm;GoPzfe)`q?)%C1L=`i-F{C6Jr%owlEh(6aV-FiLqDILbbn9HbQivAuL zV^3;k%4fc#K7OvI)+_&*=b>L!e)`;pzS3Op>H}kHjjDk@^U+tD>s_t#D_`xW)+=A- zSH8+mpZVx3&GoJhW2e8@p=PMgMENSe@>PEN%tv2ou6K193u9sI^yiiC>qyP8x9`^L z#>(zee++)l!drzuf*56yRp4OW6yZNi@)qY{@^jXQ7SG78Ws?Yk|eAV@8zc3cY z!dU(N7++_cX={D_JS|`U|9ykQJ$7lcXw4M!UD`ZVzhSKXvr^6c()=(M#(Fa!#@1tu z&t|2`kcy|z^{zgLG481}Yf$m@xxTgLeyw%udYR9yxAl7FQ#y?09#hOJhLj)1!dMuq z|34jShP{1$K3COlz1*+5p81qM++!-HT2pnss!#du{;lIxeXj4;$Nbgx%%^l1YyYg| z8LBgIedbr1`Bgl9u6K193u9!8+GUFIdt{6=ay4f}?dD@V^Hp^tW-1OnYdczcYW?ZdcE2&jD@i(qnE%KXXI+m zNUi7R+A3c8%(pjvxW`uSNu8nc*?;SJHy`U+U5ByLe~Z9nLd7e8b%s>^VJwU(Mymdb zjH!D;`J5T!=~JtC`mW}Ft@Y@7&S7kHO=_lYJ&b35<-76B=W6cPT92+Lj6J^R%+;By zc>1nZ`&a$Z^@Oo7rkJZTy#ixu2FmA*7*GFwnxCT`HS-<6fBZ_{)yF)ozK-7iz?iCi zbWNx-%CIAb2YV& z{V<-stC?@KJ}}mMPj@DaXMM_N{OnrwL*La~_aDaE|18NpQ!>N#^j*z-qjeYyW4)PM zfib>rd(UL+{Os4&te@-Y_trfaQ$5jlb?=#MouBhyJ@j48e7$uT+x}U}GbS_2XS~wP z*IS3NFc!xA-yiEev(fo^FL8hRu4ca0y2m}n8rie*8Lu?+?M<_P_DA2<%-33nvH6~r zJYzDWe8wxye64jD3uCRBV!k~v#@8u(+?!VY(sy<1d5vIy znXjtD*!<5*o-vtGKI4^UzN!vmVT{aGnd16AGS+)$oZ;45^-JH?)&90%j5V`oDBpb@t`^V!SwDSOGoR8k z++(bfJu9E_N;BWyH2Y_N^j*z-N{6x2ds5A8cAw0!Uiz+PKBdE07z<;kzdy#;jXNX8 zGvD5{>X*K&)x2iB$5=CaRzBmEX1={?_Rs$4yPEl^!&v>_PUlQ#(`1JA(RVfTQHQZG z<}$^2GDW=tW2{;ANZ-}mZ*Q9OV1M*o&3s`@|DA#}TV0bG_Cw#*%ooPO7@2UHVm$MO zu`os^$c(F*Z$-w~181Oo#w*Qyd(-Tn{n2-I7(2bkI77~SZ(8+B-_>C(jFB0>hUmLG zjD;~W6UO|XVRMGA=FBfmyYpiF9_}%9=Iq(kD*xVncV3JSV|iAp*9e($br=g{VNAdO za$i?&ygGx`eRm#i{2uNxw`Vtgb%s>^?mXQ1FqUVf%ZwWz#=@A(j2piKV{T7wyn2?b z?z{7F2Dn_b)7z<-nhSd5nrk{n?8LAog_Er5=U)8Vj&+vXs z)mZhbGg9lf_N((z>s5bheHeRkuV$oXytS|DullM#l|PJyF~vw2%lj?Gi1Mopsr4i7 zG1X($S3RoyYW>!}s=w;1{#1UoK8)#mP|dKnugX^&xazC=RDQKy z`C%-Kg|YR2rmfDf>Z^N1Dvf%*iKx`0Vg)Cyk0uf{8zQ4zFys| z7T-QzudbHco1G%%Lo1RVm@o;mzT5V;@PbE-5l`qcrux6?DgXP-;MIR zzkj~zdMUN`$?ZDVHs;mv_Iq)!z8bIXd|tm^+&J5=;ZWn9%`V$D3URrP}TcXv-V?U&U(m$f~6iT7h>t%N`A8t1~8&Doy4#QR|F z|7Ta1r*;j_W$p6fkXW!5tflUqLwM@iGbYx$<}49QU6ZjKFX2&FjbJUFYj&@6aZ(}{ ztOaYqn)Y`9Tfa)&dtEi`nv8Yrn!QAyfR_RBhqdz$;H`(%yh&efcq z^%(aw_fvcOj~!<}xtjg59^;8jRwuOv zYr&eaL3UJaby90htZ@d_i`w(~|KW0V-+uJTTCO+T-?vvopO5qB>b~a*)^?wveS5E- NFPAOuSI?iD{|}RAGR6P^ diff --git a/src/jaxatari/games/sprites/up_n_down/backround/backround11.npy b/src/jaxatari/games/sprites/up_n_down/backround/backround11.npy deleted file mode 100644 index 06b3d1675bf65113926d862950b295f1d1d9aa46..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 37180 zcmeI*Pl}vX6o&C~>K=rem4?nB!63Q-XOT>TIFiP+Bp^n*H4~FTAY=pXAS-Z(W!FHd z{&m6eY2^Q-6Qzg=AO=a0TV+dg>o)!FvTFV41awjc5D z)6HhH_wn%E!~Kwt4aos zpIg}+vond--v!T(dyM9h&C9uu_-tL-d-B+?Q_r(bMeEDCa_*F8$3Kf4%I0fyE;|F) zAJ6;nJjz|abhh6t+rRbY-0k~ZG&A$Hv5sWfd01c0m2;6JIumiUUpZIKMUF&{nMeDr z!8v=*ax^3JwQ+V{u6G_$Ud~;A=gfB>(ac=Gbhh6t+rRbZ&?r#a!yCG ztW(yTFVQ)B%Ca-C-h6Frzgc!3)|<=aTz%)OGZE{M>&@kIuAI}Eh;_*I=5je#&go3V zI^=qD`7StT&)FGSzcjYrEZeX3<`FN!IeW~qGqB!#ZEU|;b{^K7M_kU`{;njN+0t?y z$+G=gZys?uSI+5>tV2=mdU=h`*<&>$~N9XJ@%g(@h^R==4X4!dI zZ=U6HZhe2dr6XC^A?wYvT+WqqkyBBA7oD@`-ZOIlxv~9b*?z1yZ{<1M=j<`d&cJ%} zwXywX*?CxR-pb|N@;_U3B+EKvy?HB_bLE^4m2>qsGLaKH5^=Pjd2gQLJ{Qd-%AJAh z&DX~En`P%=y?Jjg=a%=`sv}v}A?wY1b2(Sem2I}i6WADhd$=|5X_qLp>X{mjSaa;}^!=WhRv^W5I2y%W|v{Z_Rt3 ow{^d|&rN@y^EGUBEUNcd#;x;->ZhL9KZlt4N&o-= diff --git a/src/jaxatari/games/sprites/up_n_down/backround/backround12.npy b/src/jaxatari/games/sprites/up_n_down/backround/backround12.npy deleted file mode 100644 index ea76f49ffd82c144453aa08bd26c71bdfbb90aaa..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 43808 zcmeI4O=}cE5Qf*Q7q4D~y{$?DM*IMe9y|mT5kx$NpeqI>N}^u;073Exyn4zn@E?xZ z!$9Z3Y)wt~Oi!(^Fzmc}tE;Q3pPdZ}=GT+wkDtE0H2XaJvUt5bes#2XFkgKAw6oZn zFW$UA`gpW|xcC0(^>X#~NBi%NmuFu;e!Kr+d8V)5zqK{rxqEMGe&@#4{M)?8zss}P ztVyrK-w&Ib*TKP+#&fY&uBAUUZPiQP(o^e?^`xw=?&*)Mz4K)J*!7m)&+3Q1rAybB zo3A!+*1YJmU(09w)V0+QeM?WRU-4G=`v>wsek`By)LdU%i#PE$CWFy3o3 zl1IG7;%#?#x5>G>FfnAz*u3S;YPELASwzSkeSYFi&$lt(XW~t~iMNZtM;UWA>-M_4 zceSZ;<$j*ls5#w}&$rYY;y(JkhM0*bXY~1rw>r-$4`XSpc=MfC^m(OcU*kUd{PcYD zyPuqml}+vdx|&Tc}1UBdiFK$qt8#zH{biEWo+Hg z=7#igc^k5}F1g8Dmo+&t(dU)Zdx^Et=O^Cu`>Yc0gRVKmn|SNWN{Kk}Cf-VL)HR2A z6K`EvDG?{$#9IlDy5*aprjecnzeO{&K;msF) ze&Vg)?-#tBl!_N`zVnJcuTtFf%@cio(VlOdci(f{ly#@4H=CSm>^Y3h8~e+d)oLyC z{ck$j{KT8@`+=$PEZ%(Q7j0hR&3C?>SM=u_d1Ze?>2-e_Zb*;)axeN`&3vUbc`Kct z`PjeZGv3n7w`rRFvp@P?&3vV`c$@a$Q{{P+AIoRFrJ1j^7H{G$Ww10~%(rR0rL$Yz zv)|~N{ZX?&`d-a^xw!VNYjpur5=CibT+w9L%ix2ukcK-i2aXwSiJa5)l+GjpXvmR>kCf-seES@Ys&GoPQXuPop z_HX%&w>0zB)~uKHSiU!&`MjF@iMR95H_nsus;#a1ET8pT@$|i#`-wO4mNG;>Qm@ZF z^NBa{mU2WsQm@ZF^NF{9ylwwDXlh(MbJk$_-gxHoYVJqg=$F>4m-#K<8_#@R&HZX? z*3Wt@-y6?-Ud{c)+xq7l=T%$tyjhRsd*hkUtGSi`ng``Q+tz^6UH2WICQ)-ps$v z&#%vJ<`=W&`>)Tx+|KUazx{mvb#|wpym&qxADtXe$Hz~nwz8;Q7qs8*N{`ab= z>+0&!V%x>_Zj0T;W4@@~rN*`x*4q}l^cY{dHfJlI4?QQvh4NWH8qfPn)3tEMzVRHC z&v>Pok6PWYe3g%W?YdOXR9|_HjHe&f%tsx~$9R>Ge(l&deE^NTJP{ zIZMG6a}I^Gwm*MG^Re$k)8;JAJR^rgK=irAc=nn;bu^y$Q!`)RbYJEf z&n>EXuIm11d^8{P*RIXk>bc2tjB=skmCyQp)8@>aReSk{is&;YxB*SIm^MUIWuSGZ2kMQ=yQkh?4|MIl`x2;t@ z`bxLGzTElD+4^UVHhWKQn47-RyuUVW&dga3X3d#7GiPfCqR$n^=ji3O8frcCmEKmX zM6c9&rus~uy;VNrQ){&z`bwu>Z`=9J*?MNS-P_H4*;nN=zH6=4OJC`(_oM2UGS5^W>2qJoXMERMt(U&iUGGQLH)s3JTh&+PlN-j< zSDN>irp;MDI3tJVtodBoCu;Uh`HU}JtNqYdy0ravHP2L^=(BIiXME{e?T5b7rR~?8 zHP6@F=R|H;pT5$(ziVyI`pFqNG-s>-XUM)N-F4qm*H`do(Y~m7wXfQJ zwg0H!bwA8m8P82}5ame4GkIvO*vJ6l;8DU%~?0|DSGaTBjrcywH1gh6opMo!z;KmNT#qvC@3jV8W2LFph6ZgB2ttf8)*1>30mIOB*n`1hpe-g z-I?7PpGd3C@wt0v=H72-6f3`9y?y!m-Mz`T$@kgCYV%<|dp4c@__~o4o& z_4}Lk#p>?!7t5>7>h|-^r{(9>t$z4)F`q8Z&gRo6kLJ^#)BpInKbcIn&2|0f`L?F( z{Cu+Azqq!2FU9X~em&mqUmRoKJB8mH&hGBDtylGye&3q)WPbX!YqNH^_mPJ#n(Qza z{o1uzGi$jRubDO0ko!t*&H0>~Hf!njHL`qaY;nD2&8(4$RNAbWwNz{|uURu|WFnO| zYi2DKTg+?L%o>?UrOldIOT`xR*325ebBJoziaMH)@yu7dHfzWCHPw5v6V)nzZN6DE zYq^*%$nLKl+PM7o<6mTrytcEuXJtJc7KDA1=U*>Pd}>F_@((~ z&8+2O)~uN|v$kWPsYUb*VtkXQxhubkYkODu=tp(id5kr#q#NVs)t@5$2V*F&QR8#>?oh{N;BWk+N_zi zTnv`P8f(J$nto|ot%rV84_%L(tVL_Xc-Bh!j32vJ>!BajL$BYg4R?mJ_hd);j8~fZ za%;0@)^agd5^H>KbJb|5anyS0M|JLXrDQF-cE+<7%4dA(TCJCUROhabSsTw8%HES5 zYe)(>p(xMnXhl%Cu`BQF`jEvKI3cGsvh*CI(7Zb z+IY`U_MYr0pYcjFU*FoSnYD(&zS&|vvsMmk>|Mi1b60*7*Y?hQU3C}M+Rph{9OuqB z`qqc+leb+Z-KC;8H=tnj4b=79gtQ}*pD_eZttd+|edr4-L&v>Po zFRCM~sktbhb7MSx>)3ms>PbJUyY|DZ#roN`-YK6aI~xn>Dl6mboZfj5llb Ru*TkAUOw3Nc`<)f{{<4M{W1Um diff --git a/src/jaxatari/games/sprites/up_n_down/backround/backround3.npy b/src/jaxatari/games/sprites/up_n_down/backround/backround3.npy deleted file mode 100644 index ee7c16619682eb58b3f417f6eb3b8f43549aaa7d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 37328 zcmeI5&1w`u6ouQZ8((2|qcDO2+5O~2>J8s?aSBi?ls?=AM^9&`on7ed^Z32?P$K3%|BkP zzOK$L-(RiHm)++t&OWb~*PpLHoqbte>xa*tEM`Z?$BWt1M~m67*)6{AH%-&_ugkwD z?a*~`@t}=)4Xq__crOm=Pt3jQ<~n4}eKCE`y8C@`+s|{#eVdcM?`JZn{Ss^Y_o)Vk z#oGV(a1HxeQVpl4O&fBB)apIgLh|g!hUBWT;b-ua`zO}I_phX9K&;)Z{hpFdu_o5W z;%`c}Sd;f1%JDhX!rY`!f_@I@_BLh@#ai8GYEEZ?SQBeQJnCGln~#wF9pdH?k85)b zeUJ8@Yu4KIz1owzVsAck-hMNlJ+*ws4{5P>bH9fAmAgkayqfiKJbjNAYhta>6WQr| z-93Ge7Heg*Mi#u9tZ+R25bb4+YgxI`_iC;uz58r<`^k9r)AAYb(PC|P?>xQBA(m`- zHEZB_`XO4ZiM2jIVy$c*`g<9Azu9cskc;bcF8VQb=()yRtla5)HP^dq+S@P2vtO3a z_z*4D$~rU2hF6mnj;9|}i#4&<=Pf2%oU1g}`hDr{Sx;@6^?Eh?!SVFR)UjD(u2yaI zy_)OQroH`OJo{n!jE||sTHfzavf$QBwkEz94Uhhz{;nie?pOemwHTK5pEq$-%da<=PAI38u%V&H}TC9!x z4ka62O;$Lbe$HB~iM3RGmd2X(T+{dJ(7y9;dA*wJb3FZ$bndLNmsaoTdo|bdYVUfC z=X#dU_?)#^+x3cQTqxR0ncxyiTC2O&^ zd#+h5aIYEf)z*CUOV(mdtfk^ttcf+TcDvuP@*T!L(oe0eeCXGtOF!4R7i7os89%nR z@}Xao7Hh-r(Aeye9nMd`CN0*)S}LAvV~u^Y&JFsp^_$<1+y2G*`(Ew$WUaP(t-Ylm zTeF{>pMK3+td08)B|BoRX}8R=H>GQ9Z|TR@>>uZ+U$fpaw*6jntzB*0>(sjT{tngm E53}cxUjP6A diff --git a/src/jaxatari/games/sprites/up_n_down/backround/backround4.npy b/src/jaxatari/games/sprites/up_n_down/backround/backround4.npy deleted file mode 100644 index 65d9d322f490b9d5430979b2e88fc2e795ab12b9..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 46944 zcmeI4Pm3H?6vf-E8^1#DT47uSqJm2ol8qvQ3s*9jkpv`>8Fk~IFX7hjI;I#N3_rR& z@63Hw(_M1{P36?N_rCk?`*n5bN%GIv-+lGX_wVg~-~DlTd3E#S_2KjV;m_Zm9WM5V zpWa^odVTrkhqu=+ucrH7UjBS@_3r-7tIJ=m-sulN`}AV}?33pg`{y5B?El(7;Me=R z-R^dJz4`Zxyb3dT{rZDj=QVoXT-&W*%~$I^l(TAI)qJ%t)_3(|*H`P)=lvR3J^VNFcESyb@R5=Z2$2sC$sny&V?`l=w^|}7q`D(qbKF_#s zp2OPu+rK})eaf%K+16*UHdkD4>+@Cn3uh1eTqQ?Jt8>`8AI`$r8jL>`&Nxf;{Hgj} zZ?ty%FEdvua<}e6U!0D_Cp=c!r8<`ID6Qi(Kt6(^ZnR#`~K0b z$M{y~sa)IhY-{`Fxw=}-!>z}7Rex)KI6MBism__)xLVcUnjg->*&0l5g){eg!g$V_ zKD8Uqd}`KHn)ynzo{Fc>d{_5!#&wluA1a&ps-RjI;tM#h+bA8R-Z=B8f?S8edYJTgZUEO+&Z*|&L`7N$pRXlxHw{_;O)q2(Z zxxVJ^hqJq%H*oHA^J@21JbhQUb>^_s@m1o>}}~4CldE z`mR>>=(}3=!+I*7el?%YVAjTJ^(vDxQ8dpY^E2SvZ@R2xrItoX)vW^ZlfJ#?R^YJ({|5J;u{_HS3+aZqKt* z>(iY(#UKTFqVKfRnS-IrUR@!Xg48Nakv{m^%{+yBz_*-toI=5v#C=D8}L@k+Cv(yAVP zRZschES!<4V_I>f{Bz}u`{HcqkJ5ABtGV~tFZ1cUn)Rqh*JFL^#u@KZkM7H@$9UBr z{ZrHKei*N`+ONu|uj(m3oIUK{yQ{e>U!5a;SG()0`l_Dt!&x{}j7`4cNcrc&8TY~2 z&_6e=_Cw#*Y96eo;$6R&GuCzcV!W%X>ofnTHh_Yjw!L+|@_)c0_z4)@D=`mScZrFDCrjb6vyH{;ol@)>{b zTJ4vzLd}St!p(8`mSDje&OsiKR0>qu?s%CI&}y#d_h)nWe{PhG2d(RcOe`G&L8__>pNuH=UK^j*z*RUOX4*)ayI zTroeKohN5JN3x=P#w*QwuI}Yb^`U(B#d!MB9{*~d^j%#&k8tLlXLa67<&zt(L*Lb` z=jw14&U%=0xng`c+XiPm*B*A97}j_7TAXoDZl8(w-VdC2RafU9&b;%h&U>x=Dp$;} z_8rc`8JRk!!`U`Dsh%tIU%c4eR$i><>fYyBrN=4P?JMzP-#?Fie&KBFd7kR~KJU5{Kb%H_*h zkzdJf8WjS48ml&HMAWA0Bl-x}UR))#lTB_Nt%#`hGH-_p{H}>znoR z>f`nLVzrxpz5Kda-Og{mEWfR8?Wc>E^ZsP9nD;N9&HLYd!k@=o*KOsw`ggkZJZER! z*7JDwX6&h}_Pf57uAy+Y8@+pZI2+=u-t)B{-+KMYn(r=ei9fOy|6OgVZ&S#nHTFFB zT5D&S0v|cg&{}`U+H}65$SGO7`=(-TXy@ggtdX@^@6_exldb3R=56_gdh*=GdUIvG z@7}B3KeFcQ-Znhph2l3o)7dT_tN|jjT`r zBdMc@o0Qm`tl|EtHIO*eqkaFqkUFJ~L)Vk6-S6Mf_4n+T%9zA?_D+cv&$Pt(`x&p0*!ZYxf&!WdZ%Q-lM-}YkJx` zStDy2W3sNXq^*ycHK~D&+F4s)mvf|!*7I5SvFf9LP1Z0M>)FvC^~Pr7d91eeyruOq zAF|fYH`LmD^vB+6HpXkVrl+lwHL|8LChHnY+WM4OlbXn=oxS$DRWBhw>8^(9`w(Bp3?ect==~jduz4HTE5;{Ymff9dy!iY^R?Qg^U1A0<$Dc# zm#cQ~{yg8NJeZHwPT8}in0?k-%Ii=r*5081Ar-M4 c&l+39kRR%y9g2&cM?J5N`&#w9@!Ed>0ih7rLI3~& diff --git a/src/jaxatari/games/sprites/up_n_down/backround/backround6.npy b/src/jaxatari/games/sprites/up_n_down/backround/backround6.npy deleted file mode 100644 index 0642e3cff1f1ba56d35e673c11980f4dfe92951d..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 34624 zcmeI2zi!k(5XMbQ!z;Kmplc!}6qFPw4Ty$<3Z2LyB1H+h0~+8dXsPm|CN8Y3zHs)A zcV_?ji)7u-d3HSW&Cg>e%AdFI-@N@IePPy6m=+x_}}(rw!A^UeO-etZ4# zW`D68zJI;FI_&PgKYZDK-QC$w&t7cWle5!Jd-`nC{%#liJZ_q%ACK$5ullLu_V#%{ zb>QA9-{IVx|J-?uFZIX$b2i48+S2%;{+M5Ayt(~T=SRkdPs|JLsW`(qZGVgxYR~0A zcRuv`rSG59XX?C!vo-59Xjppvq4`g7Ge4B z=FF{g$=X-{TK&$?n?B@lb*Y)N5RHbtu1;%QYwAbNj^DFuWXr9YoK63oZ`4Xtr!nT5 zJ~j2LX`X57vL>F!S}pY>XZO#|QSY^Wk~2OhW3W%o$Qe2FW57Zu%lpqF-`rxx)UV|C zvzR*N#)w8fG4&&7>E^KpDJ-w?T=&M*ucUbvQNPlzLOV0dyE!8QP{-(-9su%ie=AH9A!@8xmrcNcs zlIn&2nl+O%zgeoOd8}B-jnvl6lbn&WWBie``~T%}HPY0j#8|EtOV*Y(=QPijXyfj& z)TxBnt`eQafec zF*%d!g8nIW%Hi6mhty74134?_{Y=Ug`ja!hkNGfw+(>PoYgqYko^hU8vE{}kXXMP~ z!D>y;$e9&eZd`Ik&Ria>*5s@l&X57=?*ZtK@hoji&af`2ZCQ5>v88&Ve@qR?8RuCH z-pCm_BWL9@5Tkc^4SX(2ZH%XBYtA!iUFfe_*BWD?o>JRVGjdkX%#?D4{^YD47A$*- z+(>Q9+&ItbWmw7?`fG9*!WrfzwKeNttcUlDv>)iNsVO;o*dI-OQ^rDWq_$@6_Db@iul~K1vVY`Ezn{6@BPlui2e0iZE&u=k diff --git a/src/jaxatari/games/sprites/up_n_down/backround/backround7.npy b/src/jaxatari/games/sprites/up_n_down/backround/backround7.npy deleted file mode 100644 index 259a34f810ef7a6bb7bc68472c654c527964e2f2..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 47560 zcmeI2&uSDw7{u4BPqDXEiHAT$@R)-)#fyk12`1t}5=qpHdhr#!iibRbZ@6X%i}A;V z+M4e9x0JA$n)>z?2AxWu2&Up{;B`p)9Z;_K@8dA8T?eqR>Z@GGRcK+%7 z==9Cm`SD3}{prz%i<7JC7w?Zgo?PjBk00+X_a7bXEf4PREx#{2{JFbWEY{8E;j2IS z^tcC?mk-uAKiuz|UYDPlH#NO2XIwkiLEqKP+cnMo;QHvhnt4;x;cS?5oZOHD`mSc) z)O0utXG1U<&cazZ8-{@)_F26L*@LQAKJ&)bUC%SFlk1@GYUa&dtMzkzs$Tib8(W97 zVa`mK3stZDvAN>>;Vhhyi9U2Vn+9h*pLnn7yPA1J>sp*~-CPHKS2J&Jx^%tnc-D`- zU*Rn0*$@ndvv3y9{{P+c&3C~NdsB0co4q$bGn+oOThH;-%ZXWY%>u?s%YBSZA7tX?28yB^&g|l#0o2jyYFpcsuBm?ME6ug3dgU{(wjPQzu9NGa?`q~vO>;f$ zkG|5S ztXGTtFj?@h%8eOJ3&u|AxIvu&6oH?9t6 z9%qNYf2_NFcrTejpStgQ&PUDpUA^H*=u>k(`qXYc$5S)U)!{6hk*PNAa>e?&aK^sd_QJWwyuLMi<9zg~ z-TK<&RUZ4Jo+@W*UG!beedt@WKh8&=+O4lWUgfbr>TtH3?@h%8eOHqsrNdb`a~W&u z6<74bnaA1bufuhhkGhx4WWDlPPhC1*Zx+S;AJw4VJlPx-7@nt9aW?E3%ga$jm|_db=@ zvw!9(pY=*Jk2;)%vxW)9k@CqEbvO%W;q2dcK|W7>Ug*1;d9}6bhx3)zt32ha`ROx{ zI_Fu=vxW(B<7#q69nQj8IJ^D5$>)vF4}DiNuT-n~nO|C8J5TjPpYzZ4JX7nU?`rN- zsaEqdzqGz~p6Z7_=MQJs?@hH&+&5PYd_DpkGyZLW*#;3=u@kD<-2*TcXc=mXALthm*MRCUF+_Hs^@)hwVPjB z&;CmD)Oh933tr1 Syy8gt;VhhO!?5B=`F{cQg9AcR-q*Nn{ogFNp)N1F=Df6|#s?7KlVPumQHY05|XnxS@%KMk-p< z)z#IH>Bk>wZyM|*RWB>Onn^w7>Vd?wXbefw7V_;0f^BhaIxn3Bv)i0gb zJcpKIq&m&{fid(k$#-Fn9qAkx17lz;&hHP>@0`s~@*GyIDRB+oJL>_jGg$9aK!nv z$QW0zWR26Lhc$^toR5!2we~Tte#x4TMlJeEG~#@GGy!9Aeyj4)yaZp4AIX|e4PXq6 zNtiNyT4c#LXPAEE&m>*HN%GUGc2WAnc` zQXNAyY9W?tp|L|WXtj@pXwpKgh9epiQ-e|&qxEX+m>SiXOQRW$iCHrkYx6utdC}Nl z4CgNw0~)Sqj4Qv-dMka*r{)%YDIXd;r6#2?rqPAQq|~T}+*AvVow7zS*2Z~E!xfDI zV>o}g7@)jpY*+q}^;-K_NX;$8YB-`XIW<^{Fd@t%f$fYMPs}2r(8!LE3=PjIHEBrneok?iZNQR#`dkFHGhp}G^Qjyz!>^iY97HD z7z1NvG7$T@(D$6ij(tBd^Tq088eM2i%$jP*!<4^f3)bG-gR=eD~HeM(fqszW3RhzeY0}QUz!(@y%_SJi!`R_) zehj%%bC5E>l(7(fPO&L#s3CXH`hhX@vDDvDFb2lJSRMva>x_5h;5;w&tWga9de#r` T$~=7%jG>Rk=v7>fo*DZKUKXOw diff --git a/src/jaxatari/games/sprites/up_n_down/backround/backround9.npy b/src/jaxatari/games/sprites/up_n_down/backround/backround9.npy deleted file mode 100644 index ca78ccff5956da850b147227da97d14abebac977..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 31748 zcmeI*L2evX5XSLg%N>wrH-OB-65enCcA#t!VudVX!~zjy1Dhn1%~f4F=1^@rbX|Gaz0pWj?xZLe=`uC`x)akYKD{fJ+m zZZ@08>H6x&{~mi@Z{B?Nxa4K6daFn5emUCTTeqH8wqE_zBX++$!H zORu+{EL)HJ)w8UhTv>ZMpKE5vEU0H$Ke@72S(`GE&6evYpSgQwt+F;{!|Y^v@|nAj zlr?+E&SLcSZ{L3Txa76=Jl4$Gl5^U7tv#z#pEvsaDr<*#=&7?5onKiy{#lJat0)V- z_0i{7pKHheS?Qf6{moXSyFOz5dUMaL*{i*%XIZ~XW9MOi>RHyWH&@o`bFFvACjZJ> zWo^nr@B7GE^PSOqR%`p)pKLwqS=Mi9UNdXu5anaqQjxj@>u6PQO+@W!CJ`UevR!-)P+0 zZ+5@ym!4N;ZQOTglnwPPSJsaIe>F?bYHh#PY-IPle(8QM`E#xJ?3R8O`dPmmt&ep5 zBGxazC!qT4gQD*raE( z<@%AbX6@4qKK8EPxpCB2RImG8ubyT7B0i_AS*yKAy6Yp>@7%bxe)rp-dY1KzxUzOR z?@*t$*@<-5N337Om9@%Rl&LvA%9i>_S+i$*QO~k|r{<{sY<>EqE$>*FkYqj@C zcYVbAotl@fS3mnxk9eP*_t)qrSJn>y=H+Zo&F??Q5B)w{pC0a4kGQgS$WAmXb!UHS zwmvyqU$kF6;!Dn&vvn4c?)r%JTbp~=r@#HF%VteozmeXd)*khyo@LKlo1=PL>(O67 zb-A*3_#TF{$+ELwo1=PL>(O67b-A)uSu-}A{Jzqar zfAuJ9A6mcPv%mCiWzUk=_V@fOx7M58AJv!jbA8s|^_8{5_dK$*b(XSzz1j2S*7Nm~ z^;fT~Ro2YN9Ge~4ed?99L)M&4G&|SJ?(fZ>FSnkrU&NkQS*xs>i8+q4rLIq9t+Hk& z=GZJntWRaFvSudcILel~K2IrY_MT;DDL?g}rx4B4`9!?W=5$5gm!?OvQNOhQTl;Z7 n5l=bW-&@ZfZGSWy_0jt4U9a_+pCvZek%)KCkM From 16568548fff2ed9f24cbd015dc724c88f2e12d75 Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sun, 16 Nov 2025 13:34:26 +0100 Subject: [PATCH 43/76] use black backround and top and bottom wall sprites --- src/jaxatari/games/jax_upndown.py | 35 +++++++++++++++++++++++++------ 1 file changed, 29 insertions(+), 6 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 4d63a6455..e1ed585b7 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -270,8 +270,8 @@ def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: is_on_road=True, player_car=Car( position=EntityPosition( - x=50, - y=50, + x=30, + y=105, width=self.consts.PLAYER_SIZE[0], height=self.consts.PLAYER_SIZE[1], ), @@ -367,9 +367,13 @@ def __init__(self, consts: UpNDownConstants = None): #downscale=(84, 84) ) self.jr = render_utils.JaxRenderingUtils(self.config) + + background = self._createBackgroundSprite(self.config.game_dimensions) + top_block = self._createBackgroundSprite((25, self.config.game_dimensions[1])) + bottom_block = self._createBackgroundSprite((16, self.config.game_dimensions[1])) # 2. Update asset config to include both walls - asset_config = self._get_asset_config() + asset_config = self._get_asset_config(background, top_block, bottom_block) sprite_path = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/" # 3. Make a single call to the setup function @@ -381,18 +385,37 @@ def __init__(self, consts: UpNDownConstants = None): self.FLIP_OFFSETS ) = self.jr.load_and_setup_assets(asset_config, sprite_path) - def _get_asset_config(self) -> list: + def _createBackgroundSprite(self, dimensions: Tuple[int, int]) -> jnp.ndarray: + """Creates a procedural background sprite for the game.""" + height, width = dimensions + color = (0, 0, 0, 255) # RGBA for wall color + shape = (height, width, 4) # Height, Width, RGBA channels + sprite = jnp.tile(jnp.array(color, dtype=jnp.uint8), (*shape[:2], 1)) + return sprite + + def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.ndarray, bottomBlockSprite: jnp.ndarray) -> list: """Returns the declarative manifest of all assets for the game, including both wall sprites.""" return [ - {'name': 'background', 'type': 'background', 'file': 'background/background1.npy'}, + {'name': 'background', 'type': 'background', 'data': backgroundSprite}, + {'name': 'road1', 'type': 'single', 'file': 'background/background1.npy'}, {'name': 'player', 'type': 'single', 'file': 'player_car.npy'}, + {'name': 'wall_top', 'type': 'procedural', 'data': topBlockSprite}, + {'name': 'wall_bottom', 'type': 'procedural', 'data': bottomBlockSprite}, ] @partial(jax.jit, static_argnums=(0,)) def render(self, state): raster = self.jr.create_object_raster(self.BACKGROUND) + road1_mask = self.SHAPE_MASKS["road1"] + raster = self.jr.render_at(raster, 10, 25, road1_mask) player_mask = self.SHAPE_MASKS["player"] - raster = self.jr.render_at(raster, state.player_car.position.x, state.player_car.position.y, player_mask) + raster = self.jr.render_at(raster, state.player_car.position.x, 105, player_mask) + + wall_top_mask = self.SHAPE_MASKS["wall_top"] + raster = self.jr.render_at(raster, 0, 0, wall_top_mask) + + wall_bottom_mask = self.SHAPE_MASKS["wall_bottom"] + raster = self.jr.render_at(raster, 0, 210 - wall_bottom_mask.shape[0], wall_bottom_mask) return self.jr.render_from_palette(raster, self.PALETTE) \ No newline at end of file From 33414c80834018009597c000a6c746e95b4e051e Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sun, 16 Nov 2025 14:57:43 +0100 Subject: [PATCH 44/76] add first movment of player and map to game --- src/jaxatari/games/jax_upndown.py | 41 +++++++++++++++++++++---------- 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index e1ed585b7..5d57a07e2 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -20,11 +20,12 @@ class UpNDownConstants(NamedTuple): LANDING_ZONE: int = 15 FIRST_ROAD_LENGTH: int = 4 SECOND_ROAD_LENGTH: int = 4 - FIRST_TRACK_CORNERS_X: chex.Array = jnp.array([20, 50, 80, 100]) #get actual values - FIRST_TRACK_CORNERS_Y: chex.Array = jnp.array([20, 50, 80, 100]) #get actual values + FIRST_TRACK_CORNERS_X: chex.Array = jnp.array([30, 80, 140, 100]) #get actual values + FIRST_TRACK_CORNERS_Y: chex.Array = jnp.array([105, 80, 25, 100]) #get actual values SECOND_TRACK_CORNERS_X: chex.Array = jnp.array([20, 50, 80, 100]) #get actual values SECOND_TRACK_CORNERS_Y: chex.Array = jnp.array([20, 50, 80, 100]) #get actual values PLAYER_SIZE: Tuple[int, int] = (4, 16) + INITIAL_ROAD_POS_Y: int = 25 @@ -51,6 +52,7 @@ class UpNDownState(NamedTuple): is_jumping: chex.Array is_on_road: chex.Array player_car: Car + road_diff: chex.Array @@ -175,7 +177,7 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: ) - is_jumping = jnp.logical_or(jnp.logical_and(state.is_jumping, state.jump_cooldown > 0), jnp.logical_and(state.is_on_road, jnp.logical_and(player_speed > 0, state.jump_cooldown == 0))) + is_jumping = jnp.logical_or(jnp.logical_and(state.is_jumping, state.jump_cooldown > 0), jnp.logical_and(state.is_on_road, jnp.logical_and(player_speed >= 0, jnp.logical_and(state.jump_cooldown == 0, jump)))) jump_cooldown = jax.lax.cond( state.jump_cooldown > 0, lambda s: s - 1, @@ -202,21 +204,24 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: ) + car_direction_x = jax.lax.cond(state.player_car.current_road == 0, + lambda s: self.consts.FIRST_TRACK_CORNERS_X[road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_X[road_index_A], + lambda s: self.consts.SECOND_TRACK_CORNERS_X[road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[road_index_B], + operand=None), car_direction_x = jax.lax.cond( - direction_change, - lambda s: jax.lax.cond(state.player_car.current_road == 0, - lambda s: self.consts.FIRST_TRACK_CORNERS_X[road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_X[road_index_A], - lambda s: self.consts.SECOND_TRACK_CORNERS_X[road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[road_index_B], - operand=None), - lambda s: s, - operand=state.player_car.direction_x, + car_direction_x[0] > 0, + lambda s: 1, + lambda s: -1, + operand=car_direction_x, ) + is_landing = jnp.logical_and(state.jump_cooldown == 1, jump_cooldown == 0) ##calculate new position with speed (TODO: calculate better speed) player_y = state.player_car.position.y + player_speed player_x = state.player_car.position.x + player_speed * car_direction_x + jax.debug.print("Player X: {}, Player Y: {}, car_direction_x: {}", player_x, player_y, car_direction_x) landing_in_Water, between_roads, road_A_x, road_B_x = self._landing_in_water(state, player_x, player_y) landing_in_Water = jnp.logical_and(is_landing, landing_in_Water) @@ -238,6 +243,9 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: ), operand=None, ) + road_diff = state.road_diff + player_speed + jax.debug.print("road_diff: {}", road_diff) + #jax.debug.print("Player X: {}, Player Y: {}, on road: {}, jumping: {}, speed: {}, road index A: {}, road index B: {}, current road: {}", player_x, player_y, is_on_road, is_jumping, player_speed, road_index_A, road_index_B, current_road) return UpNDownState( score=state.score, difficulty=state.difficulty, @@ -258,6 +266,7 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: road_index_B=road_index_B, type=state.player_car.type, ), + road_diff=state.road_diff + player_speed, ) @@ -282,6 +291,7 @@ def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: road_index_B=0, type=0, ), + road_diff=0, ) initial_obs = self._get_observation(state) @@ -371,9 +381,10 @@ def __init__(self, consts: UpNDownConstants = None): background = self._createBackgroundSprite(self.config.game_dimensions) top_block = self._createBackgroundSprite((25, self.config.game_dimensions[1])) bottom_block = self._createBackgroundSprite((16, self.config.game_dimensions[1])) + temp_pointer = self._createBackgroundSprite((1, 1)) # 2. Update asset config to include both walls - asset_config = self._get_asset_config(background, top_block, bottom_block) + asset_config = self._get_asset_config(background, top_block, bottom_block, temp_pointer) sprite_path = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/" # 3. Make a single call to the setup function @@ -393,7 +404,7 @@ def _createBackgroundSprite(self, dimensions: Tuple[int, int]) -> jnp.ndarray: sprite = jnp.tile(jnp.array(color, dtype=jnp.uint8), (*shape[:2], 1)) return sprite - def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.ndarray, bottomBlockSprite: jnp.ndarray) -> list: + def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.ndarray, bottomBlockSprite: jnp.ndarray, tempPointer: jnp.ndarray) -> list: """Returns the declarative manifest of all assets for the game, including both wall sprites.""" return [ {'name': 'background', 'type': 'background', 'data': backgroundSprite}, @@ -401,13 +412,14 @@ def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.n {'name': 'player', 'type': 'single', 'file': 'player_car.npy'}, {'name': 'wall_top', 'type': 'procedural', 'data': topBlockSprite}, {'name': 'wall_bottom', 'type': 'procedural', 'data': bottomBlockSprite}, + {'name': 'tempPointer', 'type': 'procedural', 'data': tempPointer}, ] @partial(jax.jit, static_argnums=(0,)) def render(self, state): raster = self.jr.create_object_raster(self.BACKGROUND) road1_mask = self.SHAPE_MASKS["road1"] - raster = self.jr.render_at(raster, 10, 25, road1_mask) + raster = self.jr.render_at_clipped(raster, 10, self.consts.INITIAL_ROAD_POS_Y + state.road_diff, road1_mask) player_mask = self.SHAPE_MASKS["player"] raster = self.jr.render_at(raster, state.player_car.position.x, 105, player_mask) @@ -418,4 +430,7 @@ def render(self, state): wall_bottom_mask = self.SHAPE_MASKS["wall_bottom"] raster = self.jr.render_at(raster, 0, 210 - wall_bottom_mask.shape[0], wall_bottom_mask) + wall_bottom_mask = self.SHAPE_MASKS["tempPointer"] + raster = self.jr.render_at(raster, 140, 26, wall_bottom_mask) + return self.jr.render_from_palette(raster, self.PALETTE) \ No newline at end of file From 4aee843433b4b941b379ef636bd17af2c7de5841 Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sun, 16 Nov 2025 16:46:33 +0100 Subject: [PATCH 45/76] add new parts of the map --- src/jaxatari/games/jax_upndown.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 5d57a07e2..0e58ecafb 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -395,6 +395,7 @@ def __init__(self, consts: UpNDownConstants = None): self.COLOR_TO_ID, self.FLIP_OFFSETS ) = self.jr.load_and_setup_assets(asset_config, sprite_path) + self.road_sizes = self._get_road_sprite_sizes() def _createBackgroundSprite(self, dimensions: Tuple[int, int]) -> jnp.ndarray: """Creates a procedural background sprite for the game.""" @@ -403,12 +404,24 @@ def _createBackgroundSprite(self, dimensions: Tuple[int, int]) -> jnp.ndarray: shape = (height, width, 4) # Height, Width, RGBA channels sprite = jnp.tile(jnp.array(color, dtype=jnp.uint8), (*shape[:2], 1)) return sprite + + def _get_road_sprite_sizes(self) -> list: + """Returns the sizes of the road sprites.""" + sizes = [] + for file in os.listdir(f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/background/"): + sprite = jnp.load(f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/background/{file}") + sizes.append(sprite.shape[0]) + jax.debug.print("Road sizes: {}", sizes) + return sizes def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.ndarray, bottomBlockSprite: jnp.ndarray, tempPointer: jnp.ndarray) -> list: """Returns the declarative manifest of all assets for the game, including both wall sprites.""" + roads = [] + for x in range(13): + roads.append(f"background/background{x+1}.npy") return [ {'name': 'background', 'type': 'background', 'data': backgroundSprite}, - {'name': 'road1', 'type': 'single', 'file': 'background/background1.npy'}, + {'name': 'road', 'type': 'group', 'files': roads}, {'name': 'player', 'type': 'single', 'file': 'player_car.npy'}, {'name': 'wall_top', 'type': 'procedural', 'data': topBlockSprite}, {'name': 'wall_bottom', 'type': 'procedural', 'data': bottomBlockSprite}, @@ -418,8 +431,14 @@ def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.n @partial(jax.jit, static_argnums=(0,)) def render(self, state): raster = self.jr.create_object_raster(self.BACKGROUND) - road1_mask = self.SHAPE_MASKS["road1"] + + road1_mask = self.SHAPE_MASKS["road"][0] raster = self.jr.render_at_clipped(raster, 10, self.consts.INITIAL_ROAD_POS_Y + state.road_diff, road1_mask) + diff = 0 + for i in range(12): + road1_mask = self.SHAPE_MASKS["road"][i+1] + diff += self.road_sizes[i+1] + raster = self.jr.render_at_clipped(raster, 10, self.consts.INITIAL_ROAD_POS_Y + state.road_diff - diff, road1_mask) player_mask = self.SHAPE_MASKS["player"] raster = self.jr.render_at(raster, state.player_car.position.x, 105, player_mask) From 5db04ae6ec0925349c04b500d2f7e96e26fa4a74 Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sun, 23 Nov 2025 21:10:27 +0100 Subject: [PATCH 46/76] changes to player movement and road selection --- src/jaxatari/games/jax_upndown.py | 223 ++++++++++++++++++++---------- 1 file changed, 149 insertions(+), 74 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 0e58ecafb..6651a57b2 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -15,13 +15,13 @@ class UpNDownConstants(NamedTuple): FRAME_SKIP: int = 4 DIFFICULTIES: chex.Array = jnp.array([0, 1, 2, 3, 4, 5]) ACTION_REPEAT_PROBS: float = 0.25 - MAX_SPEED: int = 4 + MAX_SPEED: int = 1 JUMP_FRAMES: int = 10 LANDING_ZONE: int = 15 FIRST_ROAD_LENGTH: int = 4 SECOND_ROAD_LENGTH: int = 4 - FIRST_TRACK_CORNERS_X: chex.Array = jnp.array([30, 80, 140, 100]) #get actual values - FIRST_TRACK_CORNERS_Y: chex.Array = jnp.array([105, 80, 25, 100]) #get actual values + FIRST_TRACK_CORNERS_X: chex.Array = jnp.array([30, 80, 140, 80]) #get actual values + FIRST_TRACK_CORNERS_Y: chex.Array = jnp.array([105, 80, 25, 0]) #get actual values SECOND_TRACK_CORNERS_X: chex.Array = jnp.array([20, 50, 80, 100]) #get actual values SECOND_TRACK_CORNERS_Y: chex.Array = jnp.array([20, 50, 80, 100]) #get actual values PLAYER_SIZE: Tuple[int, int] = (4, 16) @@ -52,7 +52,7 @@ class UpNDownState(NamedTuple): is_jumping: chex.Array is_on_road: chex.Array player_car: Car - road_diff: chex.Array + step_counter: chex.Array @@ -89,59 +89,29 @@ def __init__(self, consts: UpNDownConstants = None, reward_funcs: list[callable] self.obs_size = 3*4+1+1 @partial(jax.jit, static_argnums=(0,)) - def _car_past_corner(self, car: Car, state: UpNDownState) -> chex.Array: - direction_change_A = jnp.logical_or(jnp.logical_and(car.speed > 0, car.position.y + car.speed > self.consts.FIRST_TRACK_CORNERS_Y[car.road_index_A+1]), jnp.logical_and(car.speed < 0, car.position.y + car.speed < self.consts.FIRST_TRACK_CORNERS_Y[car.road_index_A])) - direction_change_B = jnp.logical_or(jnp.logical_and(car.speed > 0, car.position.y + car.speed > self.consts.SECOND_TRACK_CORNERS_Y[car.road_index_B+1]), jnp.logical_and(car.speed < 0, car.position.y + car.speed < self.consts.SECOND_TRACK_CORNERS_Y[car.road_index_B])) - - road_index_A = jax.lax.cond(jnp.logical_and(direction_change_A, car.speed > 0), - lambda s: s + 1, - lambda s: s, - operand=car.road_index_A, - ) - road_index_A = jax.lax.cond(jnp.logical_and(direction_change_A, car.speed < 0), - lambda s: s - 1, - lambda s: s, - operand=car.road_index_A, - ) - - road_index_B = jax.lax.cond(jnp.logical_and(direction_change_B, car.speed > 0), - lambda s: s + 1, - lambda s: s, - operand=car.road_index_B, - ) - road_index_B = jax.lax.cond(jnp.logical_and(direction_change_B, car.speed < 0), - lambda s: s - 1, - lambda s: s, - operand=car.road_index_B, - ) - current_road_length_A = self.consts.FIRST_ROAD_LENGTH - current_road_length_B = self.consts.SECOND_ROAD_LENGTH - - road_index_A = jax.lax.cond(road_index_A < 0, - lambda s: current_road_length_A - 1, - lambda s: s, - operand=road_index_A, - ) - - road_index_A = jax.lax.cond(road_index_A >= current_road_length_A, - lambda s: 0, - lambda s: s, - operand=road_index_A, - ) - - road_index_B = jax.lax.cond(road_index_B < 0, - lambda s: current_road_length_B - 1, - lambda s: s, - operand=road_index_B, - ) - - road_index_B = jax.lax.cond(road_index_B >= current_road_length_B, - lambda s: 0, - lambda s: s, - operand=road_index_B, + def _getSlopeAndB(self, state: UpNDownState) -> chex.Array: + trackx, tracky, roadIndex = jax.lax.cond( + state.player_car.current_road == 0, + lambda s: (self.consts.FIRST_TRACK_CORNERS_X, self.consts.FIRST_TRACK_CORNERS_Y, state.player_car.road_index_A), + lambda s: (self.consts.SECOND_TRACK_CORNERS_X, self.consts.SECOND_TRACK_CORNERS_Y, state.player_car.road_index_B), + operand=None,) + slope = jax.lax.cond( + trackx[roadIndex+1] - trackx[roadIndex] != 0, + lambda s: (tracky[roadIndex+1] - tracky[roadIndex]) / (trackx[roadIndex+1] - trackx[roadIndex]), + lambda s: jnp.inf, + operand=None, ) + b = tracky[roadIndex] - slope * trackx[roadIndex] + return slope, b + + @partial(jax.jit, static_argnums=(0,)) + def _isOnLine(self, state: UpNDownState, new_position_x: chex.Array, new_position_y: chex.Array, player_speed: chex.Array) -> chex.Array: + slope, b = self._getSlopeAndB(state) + jax.debug.print("slope: {}, b: {}", slope, b) + isOnLine = jnp.less_equal(jnp.abs(jnp.round(jnp.subtract(new_position_y, slope * new_position_x + b))), player_speed) - return road_index_A, road_index_B + jax.debug.print("isOnLine: {}", jnp.subtract(new_position_y, slope * new_position_x + b)) + return isOnLine @partial(jax.jit, static_argnums=(0,)) def _landing_in_water(self, state: UpNDownState, new_position_x: chex.Array, new_position_y: chex.Array) -> chex.Array: @@ -219,8 +189,39 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: is_landing = jnp.logical_and(state.jump_cooldown == 1, jump_cooldown == 0) ##calculate new position with speed (TODO: calculate better speed) - player_y = state.player_car.position.y + player_speed - player_x = state.player_car.position.x + player_speed * car_direction_x + player_y = jax.lax.cond( + state.step_counter % 8 == 4, + lambda s: jax.lax.cond( + is_jumping, + lambda s: state.player_car.position.y + player_speed * -1, + lambda s: jax.lax.cond( + self._isOnLine(state, state.player_car.position.x, s + player_speed * -1, player_speed), + lambda s: s + player_speed * -1, + lambda s: s, + operand=state.player_car.position.y, + ), + operand=state.player_car.position.y), + lambda s: state.player_car.position.y, + operand=None, + ) + player_x = jax.lax.cond( + state.step_counter % 8 == 0, + lambda s: jax.lax.cond( + is_jumping, + lambda s: s + player_speed * car_direction_x, + lambda s: jax.lax.cond( + self._isOnLine(state, s + player_speed * car_direction_x, player_y, player_speed), + lambda s: s + player_speed * car_direction_x, + lambda s: s, + operand=state.player_car.position.x, + ), + operand=state.player_car.position.x), + lambda s: s, + operand=state.player_car.position.x, + ) + + ##if y not on mx +b then no move + jax.debug.print("Player X: {}, Player Y: {}, car_direction_x: {}", player_x, player_y, car_direction_x) landing_in_Water, between_roads, road_A_x, road_B_x = self._landing_in_water(state, player_x, player_y) @@ -243,8 +244,61 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: ), operand=None, ) - road_diff = state.road_diff + player_speed - jax.debug.print("road_diff: {}", road_diff) + + road_index_A = jax.lax.cond( + current_road == 2, + lambda s: road_index_A, + lambda s: jax.lax.cond( + self.consts.FIRST_TRACK_CORNERS_Y[road_index_A] < player_y, + lambda s: road_index_A - 1, + lambda s: jax.lax.cond( + len(self.consts.FIRST_TRACK_CORNERS_Y) == road_index_A + 1, + lambda s: jax.lax.cond( + self.consts.FIRST_TRACK_CORNERS_Y[0] > player_y, + lambda s: 0, + lambda s: road_index_A, + operand=None, + ), + lambda s: jax.lax.cond( + self.consts.FIRST_TRACK_CORNERS_Y[road_index_A+1] > player_y, + lambda s: road_index_A + 1, + lambda s: road_index_A, + operand=None, + ), + operand=None, + ), + operand=None, + ), + operand=None, + ) + + road_index_B = jax.lax.cond( + current_road == 2, + lambda s: road_index_B, + lambda s: jax.lax.cond( + self.consts.SECOND_TRACK_CORNERS_Y[road_index_B] < player_y, + lambda s: road_index_B - 1, + lambda s: jax.lax.cond( + len(self.consts.SECOND_TRACK_CORNERS_Y) == road_index_B + 1, + lambda s: jax.lax.cond( + self.consts.SECOND_TRACK_CORNERS_Y[0] > player_y, + lambda s: 0, + lambda s: road_index_B, + operand=None, + ), + lambda s: jax.lax.cond( + self.consts.SECOND_TRACK_CORNERS_Y[road_index_B+1] > player_y, + lambda s: road_index_B + 1, + lambda s: road_index_B, + operand=None, + ), + operand=None, + ), + operand=None, + ), + operand=None, + ) + #jax.debug.print("Player X: {}, Player Y: {}, on road: {}, jumping: {}, speed: {}, road index A: {}, road index B: {}, current road: {}", player_x, player_y, is_on_road, is_jumping, player_speed, road_index_A, road_index_B, current_road) return UpNDownState( score=state.score, @@ -266,7 +320,7 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: road_index_B=road_index_B, type=state.player_car.type, ), - road_diff=state.road_diff + player_speed, + step_counter=state.step_counter + 1, ) @@ -291,10 +345,9 @@ def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: road_index_B=0, type=0, ), - road_diff=0, + step_counter=jnp.array(0), ) initial_obs = self._get_observation(state) - return initial_obs, state @partial(jax.jit, static_argnums=(0,)) @@ -316,7 +369,7 @@ def render(self, state: UpNDownState) -> jnp.ndarray: def _get_observation(self, state: UpNDownState): player = EntityPosition( x=jnp.array(state.player_car.position.x), - y=state.player_car.position.y, + y=jnp.array(state.player_car.position.y), width=jnp.array(self.consts.PLAYER_SIZE[0]), height=jnp.array(self.consts.PLAYER_SIZE[1]), ) @@ -395,7 +448,7 @@ def __init__(self, consts: UpNDownConstants = None): self.COLOR_TO_ID, self.FLIP_OFFSETS ) = self.jr.load_and_setup_assets(asset_config, sprite_path) - self.road_sizes = self._get_road_sprite_sizes() + self.road_sizes, self.complete_road_size = self._get_road_sprite_sizes() def _createBackgroundSprite(self, dimensions: Tuple[int, int]) -> jnp.ndarray: """Creates a procedural background sprite for the game.""" @@ -408,11 +461,13 @@ def _createBackgroundSprite(self, dimensions: Tuple[int, int]) -> jnp.ndarray: def _get_road_sprite_sizes(self) -> list: """Returns the sizes of the road sprites.""" sizes = [] + complete_size = 0 for file in os.listdir(f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/background/"): sprite = jnp.load(f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/background/{file}") sizes.append(sprite.shape[0]) - jax.debug.print("Road sizes: {}", sizes) - return sizes + if file != "background1.npy": + complete_size += sprite.shape[0] + return sizes, complete_size def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.ndarray, bottomBlockSprite: jnp.ndarray, tempPointer: jnp.ndarray) -> list: """Returns the declarative manifest of all assets for the game, including both wall sprites.""" @@ -431,14 +486,34 @@ def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.n @partial(jax.jit, static_argnums=(0,)) def render(self, state): raster = self.jr.create_object_raster(self.BACKGROUND) + road_diff = (-state.player_car.position.y + 105) % self.complete_road_size + + # Vectorized road rendering: compute all Y offsets, stamp via vmap, fold overlays. + road_masks = self.SHAPE_MASKS["road"] # shape: (N, H, W) + num_segments = road_masks.shape[0] + + sizes = jnp.array(self.road_sizes, dtype=jnp.int32) + # Offsets: [0, cumsum(sizes[1:])] + offsets = jnp.concatenate([ + jnp.array([0], dtype=jnp.int32), + jnp.cumsum(sizes[1:], axis=0) + ], axis=0) + + base_y = jnp.asarray(self.consts.INITIAL_ROAD_POS_Y, dtype=jnp.int32) + y_positions = base_y + (road_diff.astype(jnp.int32)) - offsets + + empty_raster = jnp.full_like(self.BACKGROUND, self.jr.TRANSPARENT_ID) + + def stamp(y, mask): + return self.jr.render_at_clipped(empty_raster, 10, y, mask) + + overlays = jax.vmap(stamp)(y_positions, road_masks) + + def combine(i, acc): + over = overlays[i] + return jnp.where(over != self.jr.TRANSPARENT_ID, over, acc) - road1_mask = self.SHAPE_MASKS["road"][0] - raster = self.jr.render_at_clipped(raster, 10, self.consts.INITIAL_ROAD_POS_Y + state.road_diff, road1_mask) - diff = 0 - for i in range(12): - road1_mask = self.SHAPE_MASKS["road"][i+1] - diff += self.road_sizes[i+1] - raster = self.jr.render_at_clipped(raster, 10, self.consts.INITIAL_ROAD_POS_Y + state.road_diff - diff, road1_mask) + raster = jax.lax.fori_loop(0, num_segments, combine, raster) player_mask = self.SHAPE_MASKS["player"] raster = self.jr.render_at(raster, state.player_car.position.x, 105, player_mask) @@ -450,6 +525,6 @@ def render(self, state): raster = self.jr.render_at(raster, 0, 210 - wall_bottom_mask.shape[0], wall_bottom_mask) wall_bottom_mask = self.SHAPE_MASKS["tempPointer"] - raster = self.jr.render_at(raster, 140, 26, wall_bottom_mask) + raster = self.jr.render_at(raster, 140, 25, wall_bottom_mask) return self.jr.render_from_palette(raster, self.PALETTE) \ No newline at end of file From 11dfddcb6ab438e1abadb4b77d3c50d0b2e864c6 Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sat, 29 Nov 2025 19:26:08 +0100 Subject: [PATCH 47/76] add logic for different speeds --- src/jaxatari/games/jax_upndown.py | 72 +++++++++++++++++++++---------- 1 file changed, 49 insertions(+), 23 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 6651a57b2..d18f797b0 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -1,5 +1,6 @@ from jax._src.pjit import JitWrapped import os +import math from functools import partial from typing import NamedTuple, Tuple import jax.lax @@ -15,7 +16,7 @@ class UpNDownConstants(NamedTuple): FRAME_SKIP: int = 4 DIFFICULTIES: chex.Array = jnp.array([0, 1, 2, 3, 4, 5]) ACTION_REPEAT_PROBS: float = 0.25 - MAX_SPEED: int = 1 + MAX_SPEED: int = 4 JUMP_FRAMES: int = 10 LANDING_ZONE: int = 15 FIRST_ROAD_LENGTH: int = 4 @@ -145,6 +146,8 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: lambda s: s, operand=player_speed, ) + dividers = jnp.array([0, 1, 2, 4, 8]) + speed_divider = dividers[jnp.abs(player_speed)] is_jumping = jnp.logical_or(jnp.logical_and(state.is_jumping, state.jump_cooldown > 0), jnp.logical_and(state.is_on_road, jnp.logical_and(player_speed >= 0, jnp.logical_and(state.jump_cooldown == 0, jump)))) @@ -164,15 +167,14 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: ##check if player is on the the road is_on_road = ~state.is_jumping - road_index_A, road_index_B = self._car_past_corner(state.player_car, state) - - direction_change = jax.lax.cond( + '''direction_change = jax.lax.cond( jnp.logical_and(is_on_road, jnp.logical_or(jnp.logical_and(jnp.equal(road_index_A, state.player_car.road_index_A) , state.player_car.current_road == 0), (jnp.logical_and(jnp.equal(road_index_B, state.player_car.road_index_B) , state.player_car.current_road == 1)))) , lambda s: False, lambda s: True, operand=None, - ) - + )''' + road_index_A = state.player_car.road_index_A + road_index_B = state.player_car.road_index_B car_direction_x = jax.lax.cond(state.player_car.current_road == 0, lambda s: self.consts.FIRST_TRACK_CORNERS_X[road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_X[road_index_A], @@ -190,33 +192,33 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: ##calculate new position with speed (TODO: calculate better speed) player_y = jax.lax.cond( - state.step_counter % 8 == 4, + state.step_counter % (16/ speed_divider) == 8 / speed_divider, lambda s: jax.lax.cond( is_jumping, - lambda s: state.player_car.position.y + player_speed * -1, + lambda s: state.player_car.position.y + jax.lax.abs(player_speed) / player_speed * -1, lambda s: jax.lax.cond( - self._isOnLine(state, state.player_car.position.x, s + player_speed * -1, player_speed), - lambda s: s + player_speed * -1, - lambda s: s, + self._isOnLine(state, state.player_car.position.x, s + jax.lax.abs(player_speed) / player_speed * -1, 1), + lambda s: s + jax.lax.abs(player_speed) / player_speed * -1, + lambda s: jnp.array(s, float), operand=state.player_car.position.y, ), operand=state.player_car.position.y), - lambda s: state.player_car.position.y, - operand=None, + lambda s: jnp.array(s, float), + operand=state.player_car.position.y, ) player_x = jax.lax.cond( - state.step_counter % 8 == 0, + state.step_counter % (16/ speed_divider) == 0, lambda s: jax.lax.cond( is_jumping, - lambda s: s + player_speed * car_direction_x, + lambda s: s + jax.lax.abs(player_speed) / player_speed * car_direction_x, lambda s: jax.lax.cond( - self._isOnLine(state, s + player_speed * car_direction_x, player_y, player_speed), - lambda s: s + player_speed * car_direction_x, - lambda s: s, + self._isOnLine(state, s + jax.lax.abs(player_speed) / player_speed * car_direction_x, player_y, 1), + lambda s: s + jax.lax.abs(player_speed) / player_speed * car_direction_x, + lambda s: jnp.array(s, float), operand=state.player_car.position.x, ), operand=state.player_car.position.x), - lambda s: s, + lambda s: jnp.array(s, float), operand=state.player_car.position.x, ) @@ -449,6 +451,12 @@ def __init__(self, consts: UpNDownConstants = None): self.FLIP_OFFSETS ) = self.jr.load_and_setup_assets(asset_config, sprite_path) self.road_sizes, self.complete_road_size = self._get_road_sprite_sizes() + self.view_height = self.config.game_dimensions[0] + # Precompute offsets so repeated road tiles can wrap seamlessly without gaps. + road_cycle = max(1, self.complete_road_size) + repeats = max(1, math.ceil(self.view_height / road_cycle) + 2) + self._road_tile_offsets = jnp.arange(-repeats, repeats + 1, dtype=jnp.int32) * jnp.int32(self.complete_road_size) + self._num_road_tiles = int(self._road_tile_offsets.shape[0]) def _createBackgroundSprite(self, dimensions: Tuple[int, int]) -> jnp.ndarray: """Creates a procedural background sprite for the game.""" @@ -502,18 +510,36 @@ def render(self, state): base_y = jnp.asarray(self.consts.INITIAL_ROAD_POS_Y, dtype=jnp.int32) y_positions = base_y + (road_diff.astype(jnp.int32)) - offsets + tile_offsets = self._road_tile_offsets + tile_count = self._num_road_tiles + tiled_y = (y_positions[None, :] + tile_offsets[:, None]).reshape(tile_count * num_segments) + tiled_masks = jnp.tile(road_masks, (tile_count, 1, 1)) + tiled_sizes = jnp.tile(sizes, tile_count) + + visible = jnp.logical_and( + tiled_y < self.view_height, + (tiled_y + tiled_sizes) > 0 + ) + empty_raster = jnp.full_like(self.BACKGROUND, self.jr.TRANSPARENT_ID) - def stamp(y, mask): - return self.jr.render_at_clipped(empty_raster, 10, y, mask) + def stamp(y, mask, is_visible): + return jax.lax.cond( + is_visible, + lambda _: self.jr.render_at_clipped(empty_raster, 10, y, mask), + lambda _: empty_raster, + operand=None, + ) + + overlays = jax.vmap(stamp)(tiled_y, tiled_masks, visible) - overlays = jax.vmap(stamp)(y_positions, road_masks) + total_segments = tile_count * num_segments def combine(i, acc): over = overlays[i] return jnp.where(over != self.jr.TRANSPARENT_ID, over, acc) - raster = jax.lax.fori_loop(0, num_segments, combine, raster) + raster = jax.lax.fori_loop(0, total_segments, combine, raster) player_mask = self.SHAPE_MASKS["player"] raster = self.jr.render_at(raster, state.player_car.position.x, 105, player_mask) From 927648d4b80be5e5596241beb6c4c10fc1458b2e Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Tue, 2 Dec 2025 23:49:59 +0100 Subject: [PATCH 48/76] car now follows road with loop --- src/jaxatari/games/jax_upndown.py | 69 ++++++++++++++++++++----------- 1 file changed, 46 insertions(+), 23 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index d18f797b0..9e30ea7cd 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -21,10 +21,10 @@ class UpNDownConstants(NamedTuple): LANDING_ZONE: int = 15 FIRST_ROAD_LENGTH: int = 4 SECOND_ROAD_LENGTH: int = 4 - FIRST_TRACK_CORNERS_X: chex.Array = jnp.array([30, 80, 140, 80]) #get actual values - FIRST_TRACK_CORNERS_Y: chex.Array = jnp.array([105, 80, 25, 0]) #get actual values - SECOND_TRACK_CORNERS_X: chex.Array = jnp.array([20, 50, 80, 100]) #get actual values - SECOND_TRACK_CORNERS_Y: chex.Array = jnp.array([20, 50, 80, 100]) #get actual values + FIRST_TRACK_CORNERS_X: chex.Array = jnp.array([30, 75, 128, 75, 21, 75, 131, 111, 150, 95, 150, 115, 150, 108, 150, 115, 115, 75, 18, 67, 38, 38, 20, 64, 30]) #get actual values + FIRST_TRACK_CORNERS_Y: chex.Array = jnp.array([105, 65, 7, -50, -98, -163, -222, -242, -277, -362, -420, -460, -492, -520, -565, -600, -633, -683, -733, -793, -820, -845, -867, -895, -928]) #get actual values + SECOND_TRACK_CORNERS_X: chex.Array = FIRST_TRACK_CORNERS_X#jnp.array([20, 50]) #get actual values + SECOND_TRACK_CORNERS_Y: chex.Array = FIRST_TRACK_CORNERS_Y#jnp.array([20, 50, ]) #get actual values PLAYER_SIZE: Tuple[int, int] = (4, 16) INITIAL_ROAD_POS_Y: int = 25 @@ -54,6 +54,7 @@ class UpNDownState(NamedTuple): is_on_road: chex.Array player_car: Car step_counter: chex.Array + road_reset: chex.Array @@ -99,7 +100,7 @@ def _getSlopeAndB(self, state: UpNDownState) -> chex.Array: slope = jax.lax.cond( trackx[roadIndex+1] - trackx[roadIndex] != 0, lambda s: (tracky[roadIndex+1] - tracky[roadIndex]) / (trackx[roadIndex+1] - trackx[roadIndex]), - lambda s: jnp.inf, + lambda s: 300.0, operand=None, ) b = tracky[roadIndex] - slope * trackx[roadIndex] @@ -109,7 +110,7 @@ def _getSlopeAndB(self, state: UpNDownState) -> chex.Array: def _isOnLine(self, state: UpNDownState, new_position_x: chex.Array, new_position_y: chex.Array, player_speed: chex.Array) -> chex.Array: slope, b = self._getSlopeAndB(state) jax.debug.print("slope: {}, b: {}", slope, b) - isOnLine = jnp.less_equal(jnp.abs(jnp.round(jnp.subtract(new_position_y, slope * new_position_x + b))), player_speed) + isOnLine = jnp.logical_or(jnp.logical_and(jnp.equal(slope, 300.0), jnp.equal(new_position_x, state.player_car.position.x)), jnp.less_equal(jnp.abs(jnp.round(jnp.subtract(new_position_y, slope * new_position_x + b))), player_speed)) jax.debug.print("isOnLine: {}", jnp.subtract(new_position_y, slope * new_position_x + b)) return isOnLine @@ -301,6 +302,22 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: operand=None, ) + player_y = jax.lax.cond( + state.road_reset, + lambda s: 105.0, + lambda s: s, + operand=player_y, + ) + + road_reset = jax.lax.cond( + jnp.equal(player_y, -928), + lambda s: True, + lambda s: False, + operand=None, + ) + + + #jax.debug.print("Player X: {}, Player Y: {}, on road: {}, jumping: {}, speed: {}, road index A: {}, road index B: {}, current road: {}", player_x, player_y, is_on_road, is_jumping, player_speed, road_index_A, road_index_B, current_road) return UpNDownState( score=state.score, @@ -308,6 +325,7 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: jump_cooldown=jump_cooldown, is_jumping=is_jumping, is_on_road=is_on_road, + road_reset=road_reset, player_car=Car( position=EntityPosition( x=player_x, @@ -333,10 +351,11 @@ def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: jump_cooldown=0, is_jumping=False, is_on_road=True, + road_reset=False, player_car=Car( position=EntityPosition( x=30, - y=105, + y= 105, width=self.consts.PLAYER_SIZE[0], height=self.consts.PLAYER_SIZE[1], ), @@ -439,7 +458,7 @@ def __init__(self, consts: UpNDownConstants = None): temp_pointer = self._createBackgroundSprite((1, 1)) # 2. Update asset config to include both walls - asset_config = self._get_asset_config(background, top_block, bottom_block, temp_pointer) + asset_config, road_files = self._get_asset_config(background, top_block, bottom_block, temp_pointer) sprite_path = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/" # 3. Make a single call to the setup function @@ -450,7 +469,7 @@ def __init__(self, consts: UpNDownConstants = None): self.COLOR_TO_ID, self.FLIP_OFFSETS ) = self.jr.load_and_setup_assets(asset_config, sprite_path) - self.road_sizes, self.complete_road_size = self._get_road_sprite_sizes() + self.road_sizes, self.complete_road_size = self._get_road_sprite_sizes(road_files) self.view_height = self.config.game_dimensions[0] # Precompute offsets so repeated road tiles can wrap seamlessly without gaps. road_cycle = max(1, self.complete_road_size) @@ -466,22 +485,26 @@ def _createBackgroundSprite(self, dimensions: Tuple[int, int]) -> jnp.ndarray: sprite = jnp.tile(jnp.array(color, dtype=jnp.uint8), (*shape[:2], 1)) return sprite - def _get_road_sprite_sizes(self) -> list: - """Returns the sizes of the road sprites.""" + def _get_road_sprite_sizes(self, road_files: list[str]) -> list: + """Returns the sizes of the road sprites limited to the configured files.""" + road_dir = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/roads" sizes = [] - complete_size = 0 - for file in os.listdir(f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/background/"): - sprite = jnp.load(f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/background/{file}") + for file in road_files: + sprite_name = os.path.basename(file) + sprite = jnp.load(f"{road_dir}/{sprite_name}") sizes.append(sprite.shape[0]) - if file != "background1.npy": - complete_size += sprite.shape[0] + complete_size = int(sum(sizes)) + jax.debug.print("Complete road size: {}", complete_size) return sizes, complete_size - def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.ndarray, bottomBlockSprite: jnp.ndarray, tempPointer: jnp.ndarray) -> list: - """Returns the declarative manifest of all assets for the game, including both wall sprites.""" - roads = [] - for x in range(13): - roads.append(f"background/background{x+1}.npy") + def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.ndarray, bottomBlockSprite: jnp.ndarray, tempPointer: jnp.ndarray) -> tuple[list, list[str]]: + """Returns the asset manifest and ordered road files.""" + road_dir = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/roads" + road_files = sorted( + file for file in os.listdir(road_dir) + if file.endswith(".npy") + ) + roads = [f"roads/{file}" for file in road_files] return [ {'name': 'background', 'type': 'background', 'data': backgroundSprite}, {'name': 'road', 'type': 'group', 'files': roads}, @@ -489,7 +512,7 @@ def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.n {'name': 'wall_top', 'type': 'procedural', 'data': topBlockSprite}, {'name': 'wall_bottom', 'type': 'procedural', 'data': bottomBlockSprite}, {'name': 'tempPointer', 'type': 'procedural', 'data': tempPointer}, - ] + ], roads @partial(jax.jit, static_argnums=(0,)) def render(self, state): @@ -512,7 +535,7 @@ def render(self, state): tile_offsets = self._road_tile_offsets tile_count = self._num_road_tiles - tiled_y = (y_positions[None, :] + tile_offsets[:, None]).reshape(tile_count * num_segments) + tiled_y = (y_positions[None, :] + tile_offsets[:, None]).reshape(-1) tiled_masks = jnp.tile(road_masks, (tile_count, 1, 1)) tiled_sizes = jnp.tile(sizes, tile_count) From e74be01e3a1e41a22d781892c0d87d38d2aceea5 Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Wed, 3 Dec 2025 00:00:17 +0100 Subject: [PATCH 49/76] remove offset, add moving backwards --- src/jaxatari/games/jax_upndown.py | 24 ++++-------------------- 1 file changed, 4 insertions(+), 20 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 9e30ea7cd..d86c3cc00 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -22,7 +22,7 @@ class UpNDownConstants(NamedTuple): FIRST_ROAD_LENGTH: int = 4 SECOND_ROAD_LENGTH: int = 4 FIRST_TRACK_CORNERS_X: chex.Array = jnp.array([30, 75, 128, 75, 21, 75, 131, 111, 150, 95, 150, 115, 150, 108, 150, 115, 115, 75, 18, 67, 38, 38, 20, 64, 30]) #get actual values - FIRST_TRACK_CORNERS_Y: chex.Array = jnp.array([105, 65, 7, -50, -98, -163, -222, -242, -277, -362, -420, -460, -492, -520, -565, -600, -633, -683, -733, -793, -820, -845, -867, -895, -928]) #get actual values + FIRST_TRACK_CORNERS_Y: chex.Array = jnp.array([0, -40, -98, -155, -203, -268, -327, -347, -382, -467, -525, -565, -597, -625, -670, -705, -738, -788, -838, -898, -925, -950, -972, -1000, -1033]) #get actual values SECOND_TRACK_CORNERS_X: chex.Array = FIRST_TRACK_CORNERS_X#jnp.array([20, 50]) #get actual values SECOND_TRACK_CORNERS_Y: chex.Array = FIRST_TRACK_CORNERS_Y#jnp.array([20, 50, ]) #get actual values PLAYER_SIZE: Tuple[int, int] = (4, 16) @@ -54,7 +54,6 @@ class UpNDownState(NamedTuple): is_on_road: chex.Array player_car: Car step_counter: chex.Array - road_reset: chex.Array @@ -302,19 +301,6 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: operand=None, ) - player_y = jax.lax.cond( - state.road_reset, - lambda s: 105.0, - lambda s: s, - operand=player_y, - ) - - road_reset = jax.lax.cond( - jnp.equal(player_y, -928), - lambda s: True, - lambda s: False, - operand=None, - ) @@ -325,11 +311,10 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: jump_cooldown=jump_cooldown, is_jumping=is_jumping, is_on_road=is_on_road, - road_reset=road_reset, player_car=Car( position=EntityPosition( x=player_x, - y=player_y, + y=-((player_y * -1) % 1036), width=state.player_car.position.width, height=state.player_car.position.height, ), @@ -351,11 +336,10 @@ def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: jump_cooldown=0, is_jumping=False, is_on_road=True, - road_reset=False, player_car=Car( position=EntityPosition( x=30, - y= 105, + y= 0, width=self.consts.PLAYER_SIZE[0], height=self.consts.PLAYER_SIZE[1], ), @@ -517,7 +501,7 @@ def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.n @partial(jax.jit, static_argnums=(0,)) def render(self, state): raster = self.jr.create_object_raster(self.BACKGROUND) - road_diff = (-state.player_car.position.y + 105) % self.complete_road_size + road_diff = (-state.player_car.position.y) % self.complete_road_size # Vectorized road rendering: compute all Y offsets, stamp via vmap, fold overlays. road_masks = self.SHAPE_MASKS["road"] # shape: (N, H, W) From 2ee59e5a651237b0a44e0137ddf674dab671f9d1 Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Fri, 5 Dec 2025 21:45:35 +0100 Subject: [PATCH 50/76] add second road --- src/jaxatari/games/jax_upndown.py | 61 +++++++++++++++---------------- 1 file changed, 30 insertions(+), 31 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index d86c3cc00..a75917186 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -21,12 +21,11 @@ class UpNDownConstants(NamedTuple): LANDING_ZONE: int = 15 FIRST_ROAD_LENGTH: int = 4 SECOND_ROAD_LENGTH: int = 4 - FIRST_TRACK_CORNERS_X: chex.Array = jnp.array([30, 75, 128, 75, 21, 75, 131, 111, 150, 95, 150, 115, 150, 108, 150, 115, 115, 75, 18, 67, 38, 38, 20, 64, 30]) #get actual values - FIRST_TRACK_CORNERS_Y: chex.Array = jnp.array([0, -40, -98, -155, -203, -268, -327, -347, -382, -467, -525, -565, -597, -625, -670, -705, -738, -788, -838, -898, -925, -950, -972, -1000, -1033]) #get actual values - SECOND_TRACK_CORNERS_X: chex.Array = FIRST_TRACK_CORNERS_X#jnp.array([20, 50]) #get actual values - SECOND_TRACK_CORNERS_Y: chex.Array = FIRST_TRACK_CORNERS_Y#jnp.array([20, 50, ]) #get actual values + FIRST_TRACK_CORNERS_X: chex.Array = jnp.array([30, 75, 128, 75, 21, 75, 131, 111, 150, 95, 150, 115, 150, 108, 150, 115, 115, 75, 18, 38, 67, 38, 38, 20, 64, 30]) + TRACK_CORNERS_Y: chex.Array = jnp.array([0, -40, -98, -155, -203, -268, -327, -347, -382, -467, -525, -565, -597, -625, -670, -705, -738, -788, -838, -862, -898, -925, -950, -972, -1000, -1036]) + SECOND_TRACK_CORNERS_X: chex.Array = jnp.array([115, 75, 20, 75, 133, 75, 22, 37, 63, 27, 66, 30, 63, 24, 60, 38, 38, 75, 131, 111, 150, 118, 118, 98, 150, 115]) PLAYER_SIZE: Tuple[int, int] = (4, 16) - INITIAL_ROAD_POS_Y: int = 25 + INITIAL_ROAD_POS_Y: int = 25 @@ -93,8 +92,8 @@ def __init__(self, consts: UpNDownConstants = None, reward_funcs: list[callable] def _getSlopeAndB(self, state: UpNDownState) -> chex.Array: trackx, tracky, roadIndex = jax.lax.cond( state.player_car.current_road == 0, - lambda s: (self.consts.FIRST_TRACK_CORNERS_X, self.consts.FIRST_TRACK_CORNERS_Y, state.player_car.road_index_A), - lambda s: (self.consts.SECOND_TRACK_CORNERS_X, self.consts.SECOND_TRACK_CORNERS_Y, state.player_car.road_index_B), + lambda s: (self.consts.FIRST_TRACK_CORNERS_X, self.consts.TRACK_CORNERS_Y, state.player_car.road_index_A), + lambda s: (self.consts.SECOND_TRACK_CORNERS_X, self.consts.TRACK_CORNERS_Y, state.player_car.road_index_B), operand=None,) slope = jax.lax.cond( trackx[roadIndex+1] - trackx[roadIndex] != 0, @@ -106,18 +105,20 @@ def _getSlopeAndB(self, state: UpNDownState) -> chex.Array: return slope, b @partial(jax.jit, static_argnums=(0,)) - def _isOnLine(self, state: UpNDownState, new_position_x: chex.Array, new_position_y: chex.Array, player_speed: chex.Array) -> chex.Array: + def _isOnLine(self, state: UpNDownState, player_speed: chex.Array, turn: chex.Array) -> chex.Array: slope, b = self._getSlopeAndB(state) - jax.debug.print("slope: {}, b: {}", slope, b) - isOnLine = jnp.logical_or(jnp.logical_and(jnp.equal(slope, 300.0), jnp.equal(new_position_x, state.player_car.position.x)), jnp.less_equal(jnp.abs(jnp.round(jnp.subtract(new_position_y, slope * new_position_x + b))), player_speed)) - - jax.debug.print("isOnLine: {}", jnp.subtract(new_position_y, slope * new_position_x + b)) - return isOnLine + x_step = abs(jnp.subtract(state.player_car.position.y, slope * (state.player_car.position.x) + b)) + y_step = abs(jnp.subtract(state.player_car.position.y - player_speed, slope * state.player_car.position.x + b)) + prefer_y = jnp.less_equal(y_step, x_step) + return jnp.logical_or( + jnp.logical_and(turn == 1, prefer_y), + jnp.logical_and(turn == 2, jnp.logical_not(prefer_y)), + ) @partial(jax.jit, static_argnums=(0,)) def _landing_in_water(self, state: UpNDownState, new_position_x: chex.Array, new_position_y: chex.Array) -> chex.Array: - road_A_x = ((new_position_y - self.consts.FIRST_TRACK_CORNERS_Y[state.player_car.road_index_A]) / (self.consts.FIRST_TRACK_CORNERS_Y[state.player_car.road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_Y[state.player_car.road_index_A])) * (self.consts.FIRST_TRACK_CORNERS_X[state.player_car.road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_X[state.player_car.road_index_A]) + self.consts.FIRST_TRACK_CORNERS_X[state.player_car.road_index_A] - road_B_x = ((new_position_y - self.consts.SECOND_TRACK_CORNERS_Y[state.player_car.road_index_B]) / (self.consts.SECOND_TRACK_CORNERS_Y[state.player_car.road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_Y[state.player_car.road_index_B])) * (self.consts.SECOND_TRACK_CORNERS_X[state.player_car.road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[state.player_car.road_index_B]) + self.consts.SECOND_TRACK_CORNERS_X[state.player_car.road_index_B] + road_A_x = ((new_position_y - self.consts.TRACK_CORNERS_Y[state.player_car.road_index_A]) / (self.consts.TRACK_CORNERS_Y[state.player_car.road_index_A+1] - self.consts.TRACK_CORNERS_Y[state.player_car.road_index_A])) * (self.consts.FIRST_TRACK_CORNERS_X[state.player_car.road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_X[state.player_car.road_index_A]) + self.consts.FIRST_TRACK_CORNERS_X[state.player_car.road_index_A] + road_B_x = ((new_position_y - self.consts.TRACK_CORNERS_Y[state.player_car.road_index_B]) / (self.consts.TRACK_CORNERS_Y[state.player_car.road_index_B+1] - self.consts.TRACK_CORNERS_Y[state.player_car.road_index_B])) * (self.consts.SECOND_TRACK_CORNERS_X[state.player_car.road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[state.player_car.road_index_B]) + self.consts.SECOND_TRACK_CORNERS_X[state.player_car.road_index_B] distance_to_road_A = jnp.abs(new_position_x - road_A_x) distance_to_road_B = jnp.abs(new_position_x - road_B_x) landing_in_Water = jnp.logical_and(distance_to_road_A > self.consts.LANDING_ZONE, distance_to_road_B > self.consts.LANDING_ZONE) @@ -187,17 +188,15 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: operand=car_direction_x, ) - is_landing = jnp.logical_and(state.jump_cooldown == 1, jump_cooldown == 0) - ##calculate new position with speed (TODO: calculate better speed) player_y = jax.lax.cond( - state.step_counter % (16/ speed_divider) == 8 / speed_divider, + jnp.logical_and((state.step_counter % (16/ speed_divider) == 8 / speed_divider), player_speed != 0,), lambda s: jax.lax.cond( is_jumping, lambda s: state.player_car.position.y + jax.lax.abs(player_speed) / player_speed * -1, lambda s: jax.lax.cond( - self._isOnLine(state, state.player_car.position.x, s + jax.lax.abs(player_speed) / player_speed * -1, 1), + self._isOnLine(state, jax.lax.abs(player_speed) / player_speed, 1), lambda s: s + jax.lax.abs(player_speed) / player_speed * -1, lambda s: jnp.array(s, float), operand=state.player_car.position.y, @@ -207,12 +206,12 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: operand=state.player_car.position.y, ) player_x = jax.lax.cond( - state.step_counter % (16/ speed_divider) == 0, + jnp.logical_and((state.step_counter % (16/ speed_divider) == 0), player_speed != 0,), lambda s: jax.lax.cond( is_jumping, lambda s: s + jax.lax.abs(player_speed) / player_speed * car_direction_x, lambda s: jax.lax.cond( - self._isOnLine(state, s + jax.lax.abs(player_speed) / player_speed * car_direction_x, player_y, 1), + self._isOnLine(state, jax.lax.abs(player_speed) / player_speed, 2), lambda s: s + jax.lax.abs(player_speed) / player_speed * car_direction_x, lambda s: jnp.array(s, float), operand=state.player_car.position.x, @@ -224,7 +223,7 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: ##if y not on mx +b then no move - jax.debug.print("Player X: {}, Player Y: {}, car_direction_x: {}", player_x, player_y, car_direction_x) + landing_in_Water, between_roads, road_A_x, road_B_x = self._landing_in_water(state, player_x, player_y) landing_in_Water = jnp.logical_and(is_landing, landing_in_Water) @@ -251,18 +250,18 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: current_road == 2, lambda s: road_index_A, lambda s: jax.lax.cond( - self.consts.FIRST_TRACK_CORNERS_Y[road_index_A] < player_y, + self.consts.TRACK_CORNERS_Y[road_index_A] < player_y, lambda s: road_index_A - 1, lambda s: jax.lax.cond( - len(self.consts.FIRST_TRACK_CORNERS_Y) == road_index_A + 1, + len(self.consts.TRACK_CORNERS_Y) == road_index_A + 1, lambda s: jax.lax.cond( - self.consts.FIRST_TRACK_CORNERS_Y[0] > player_y, + self.consts.TRACK_CORNERS_Y[0] > player_y, lambda s: 0, lambda s: road_index_A, operand=None, ), lambda s: jax.lax.cond( - self.consts.FIRST_TRACK_CORNERS_Y[road_index_A+1] > player_y, + self.consts.TRACK_CORNERS_Y[road_index_A+1] > player_y, lambda s: road_index_A + 1, lambda s: road_index_A, operand=None, @@ -278,18 +277,18 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: current_road == 2, lambda s: road_index_B, lambda s: jax.lax.cond( - self.consts.SECOND_TRACK_CORNERS_Y[road_index_B] < player_y, + self.consts.TRACK_CORNERS_Y[road_index_B] < player_y, lambda s: road_index_B - 1, lambda s: jax.lax.cond( - len(self.consts.SECOND_TRACK_CORNERS_Y) == road_index_B + 1, + len(self.consts.TRACK_CORNERS_Y) == road_index_B + 1, lambda s: jax.lax.cond( - self.consts.SECOND_TRACK_CORNERS_Y[0] > player_y, + self.consts.TRACK_CORNERS_Y[0] > player_y, lambda s: 0, lambda s: road_index_B, operand=None, ), lambda s: jax.lax.cond( - self.consts.SECOND_TRACK_CORNERS_Y[road_index_B+1] > player_y, + self.consts.TRACK_CORNERS_Y[road_index_B+1] > player_y, lambda s: road_index_B + 1, lambda s: road_index_B, operand=None, @@ -301,7 +300,7 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: operand=None, ) - + jax.debug.print("Player X: {}, Player Y: {}, car_direction_x: {}", player_x, player_y, car_direction_x) #jax.debug.print("Player X: {}, Player Y: {}, on road: {}, jumping: {}, speed: {}, road index A: {}, road index B: {}, current road: {}", player_x, player_y, is_on_road, is_jumping, player_speed, road_index_A, road_index_B, current_road) From 729185db5669480b2e87446f4935259e81810d76 Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sat, 13 Dec 2025 20:40:38 +0100 Subject: [PATCH 51/76] add collectibles to the game --- src/jaxatari/games/jax_upndown.py | 624 +++++++++++++++++++++++++++++- 1 file changed, 612 insertions(+), 12 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index a75917186..e469364bf 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -26,6 +26,46 @@ class UpNDownConstants(NamedTuple): SECOND_TRACK_CORNERS_X: chex.Array = jnp.array([115, 75, 20, 75, 133, 75, 22, 37, 63, 27, 66, 30, 63, 24, 60, 38, 38, 75, 131, 111, 150, 118, 118, 98, 150, 115]) PLAYER_SIZE: Tuple[int, int] = (4, 16) INITIAL_ROAD_POS_Y: int = 25 + # Flag constants - 8 flags with different colors matching the top row + NUM_FLAGS: int = 8 + FLAG_SIZE: Tuple[int, int] = (11, 6) # height, width of the flag sprite + FLAG_POLE_SIZE: Tuple[int, int] = (7, 2) # height, width of the pole sprite + # Flag colors as RGBA values (matching the top row from left to right) + FLAG_COLORS: chex.Array = jnp.array([ + [184, 50, 50, 255], # Red + [181, 83, 40, 255], # Orange + [162, 98, 33, 255], # Dark orange + [134, 134, 29, 255], # Yellow/olive + [200, 72, 72, 255], # Pink (original) + [168, 48, 143, 255], # Magenta + [125, 48, 173, 255], # Purple + [78, 50, 181, 255], # Blue + ]) + # Top display positions for each flag (x coordinates where blackout squares appear) + FLAG_TOP_X_POSITIONS: chex.Array = jnp.array([13, 30, 47, 64, 82, 98, 118, 132]) + FLAG_TOP_Y: int = 20 + FLAG_BLACKOUT_SIZE: Tuple[int, int] = (14, 14) # Size of blackout square + FLAG_COLLECTION_SCORE: int = 75 # Points awarded for collecting a flag + PICKUP_SCORE: int = 100 # Points awarded for jumping on a pickup truck + FLAG_CARRIER_SCORE: int = 125 # Points awarded for jumping on a flag carrier + CAMARO_SCORE: int = 150 # Points awarded for jumping on a camaro + TRUCK_SCORE: int = 175 # Points awarded for jumping on a truck + # Collectible constants - unified dynamic spawning + MAX_COLLECTIBLES: int = 2 # Maximum collectibles that can exist at once (pool of mixed types) + COLLECTIBLE_SIZE: Tuple[int, int] = (8, 8) # height, width of collectible sprite + COLLECTIBLE_SPAWN_INTERVAL: int = 200 # Steps between spawn attempts + COLLECTIBLE_DESPAWN_DISTANCE: int = 500 # Distance beyond which collectibles despawn + # Collectible types (indices for type field) + COLLECTIBLE_TYPE_CHERRY: int = 0 + COLLECTIBLE_TYPE_BALLOON: int = 1 + COLLECTIBLE_TYPE_LOLLYPOP: int = 2 + COLLECTIBLE_TYPE_ICE_CREAM: int = 3 + # Collectible type spawn probabilities (must sum to 100) + COLLECTIBLE_SPAWN_PROBABILITIES: chex.Array = jnp.array([40, 20, 20, 20], dtype=jnp.int32) # Cherry: 40%, Balloon: 20%, Lollypop: 20%, IceCream: 20% + # Collectible type scores + COLLECTIBLE_SCORES: chex.Array = jnp.array([50, 65, 70, 75], dtype=jnp.int32) # [cherry, balloon, lollypop, ice_cream] + # Shared collectible colors + COLLECTIBLE_COLORS: chex.Array = FLAG_COLORS @@ -45,6 +85,27 @@ class Car(NamedTuple): road_index_B: chex.Array direction_x: chex.Array +class Flag(NamedTuple): + """Represents a collectible flag on the road.""" + y: chex.Array # Y position in world coordinates (like player_car.position.y) + road: chex.Array # Which road the flag is on (0 or 1) + road_segment: chex.Array # Which road segment index the flag is on + color_idx: chex.Array # Index into FLAG_COLORS array + collected: chex.Array # Whether this flag has been collected + +class Collectible(NamedTuple): + """Represents a dynamically spawning collectible item on the road. + + Can be any type: cherry (0), balloon (1), lollypop (2), or ice cream (3). + The type determines the sprite and point value. + """ + y: chex.Array # Y position in world coordinates + x: chex.Array # X position on the road + road: chex.Array # Which road the collectible is on (0 or 1) + color_idx: chex.Array # Index into COLLECTIBLE_COLORS array + type_id: chex.Array # Type of collectible (0=cherry, 1=balloon, 2=lollypop, 3=ice_cream) + active: chex.Array # Whether this collectible slot is active (spawned) + class UpNDownState(NamedTuple): score: chex.Array difficulty: chex.Array @@ -53,6 +114,12 @@ class UpNDownState(NamedTuple): is_on_road: chex.Array player_car: Car step_counter: chex.Array + # Flag state - tracks all 8 flags + flags: Flag # Contains arrays of size NUM_FLAGS for each field + flags_collected_mask: chex.Array # Boolean mask of which flag colors have been collected (size NUM_FLAGS) + # Collectible state - dynamic spawning (mixed types: cherry, balloon, lollypop, ice cream) + collectibles: Collectible # Contains arrays of size MAX_COLLECTIBLES for each field + collectible_spawn_timer: chex.Array # Counter for collectible spawn timing @@ -60,12 +127,6 @@ class UpNDownState(NamedTuple): class UpNDownObservation(NamedTuple): player: EntityPosition -class Collectible(NamedTuple): - position: EntityPosition - type: chex.Array - value: chex.Array - - class UpNDownInfo(NamedTuple): time: jnp.ndarray @@ -104,6 +165,23 @@ def _getSlopeAndB(self, state: UpNDownState) -> chex.Array: b = tracky[roadIndex] - slope * trackx[roadIndex] return slope, b + @partial(jax.jit, static_argnums=(0,)) + def _get_x_on_road(self, y: chex.Array, road_segment: chex.Array, track_corners_x: chex.Array) -> chex.Array: + """Calculate the X position on a road given a Y coordinate and road segment.""" + y1 = self.consts.TRACK_CORNERS_Y[road_segment] + y2 = self.consts.TRACK_CORNERS_Y[road_segment + 1] + x1 = track_corners_x[road_segment] + x2 = track_corners_x[road_segment + 1] + + # Linear interpolation: x = x1 + (y - y1) * (x2 - x1) / (y2 - y1) + t = jax.lax.cond( + y2 != y1, + lambda _: (y - y1) / (y2 - y1), + lambda _: 0.0, + operand=None, + ) + return x1 + t * (x2 - x1) + @partial(jax.jit, static_argnums=(0,)) def _isOnLine(self, state: UpNDownState, player_speed: chex.Array, turn: chex.Array) -> chex.Array: slope, b = self._getSlopeAndB(state) @@ -125,6 +203,223 @@ def _landing_in_water(self, state: UpNDownState, new_position_x: chex.Array, new between_roads = jnp.logical_and(new_position_x > jnp.minimum(road_A_x, road_B_x), new_position_x < jnp.maximum(road_A_x, road_B_x)) return landing_in_Water, between_roads, road_A_x, road_B_x + def _flag_step(self, state: UpNDownState, new_player_y: chex.Array, player_x: chex.Array, current_road: chex.Array) -> Tuple[Flag, chex.Array, chex.Array]: + """Update flag collection state and score. + + Args: + state: Current game state + new_player_y: Updated player Y position after movement + player_x: Current player X position + current_road: Current road player is on + + Returns: + Tuple of (updated_flags, score_delta, flags_collected_mask) + """ + # Check collision for each flag + def check_flag_collision(flag_idx): + flag_y = state.flags.y[flag_idx] + flag_road = state.flags.road[flag_idx] + flag_collected = state.flags.collected[flag_idx] + + # Calculate flag X position on its road + flag_segment = state.flags.road_segment[flag_idx] + flag_x = jax.lax.cond( + flag_road == 0, + lambda _: self._get_x_on_road(flag_y, flag_segment, self.consts.FIRST_TRACK_CORNERS_X), + lambda _: self._get_x_on_road(flag_y, flag_segment, self.consts.SECOND_TRACK_CORNERS_X), + operand=None, + ) + + # Check if player is close enough to collect the flag + y_distance = jnp.abs(new_player_y - flag_y) + x_distance = jnp.abs(player_x - flag_x) + same_road = jnp.logical_or( + jnp.logical_and(current_road == 0, flag_road == 0), + jnp.logical_and(current_road == 1, flag_road == 1), + ) + + collision = jnp.logical_and( + jnp.logical_and(y_distance < 5, x_distance < 5), #change the distance threshold if needed + jnp.logical_and(same_road, ~flag_collected) + ) + return collision + + new_collections = jax.vmap(check_flag_collision)(jnp.arange(self.consts.NUM_FLAGS)) + + # Update flags collected state + new_flags_collected = jnp.logical_or(state.flags.collected, new_collections) + new_flags_collected_mask = jnp.logical_or(state.flags_collected_mask, new_collections) + + # Update score based on collected flags + flag_score = jnp.sum(new_collections.astype(jnp.int32) * self.consts.FLAG_COLLECTION_SCORE) + + new_flags = Flag( + y=state.flags.y, + road=state.flags.road, + road_segment=state.flags.road_segment, + color_idx=state.flags.color_idx, + collected=new_flags_collected, + ) + + return new_flags, flag_score, new_flags_collected_mask + + def _collectible_step(self, state: UpNDownState, new_player_y: chex.Array, player_x: chex.Array, current_road: chex.Array) -> Tuple[Collectible, chex.Array, chex.Array]: + """Update collectible spawning, despawning, and collection (unified for all types). + + Handles mixed-type collectibles (cherry, balloon, lollypop, ice cream) in a single pool. + Type is randomized on spawn with probabilities defined in COLLECTIBLE_SPAWN_PROBABILITIES. + + Args: + state: Current game state + new_player_y: Updated player Y position after movement + player_x: Current player X position + current_road: Current road player is on + + Returns: + Tuple of (updated_collectibles, score_delta, new_spawn_timer) + """ + # Collectible spawning logic - decrement timer and spawn when ready + new_collectible_timer = jax.lax.cond( + state.collectible_spawn_timer <= 0, + lambda _: self.consts.COLLECTIBLE_SPAWN_INTERVAL, + lambda _: state.collectible_spawn_timer - 1, + operand=None, + ) + + # Attempt to spawn when timer hits 0 + should_spawn = state.collectible_spawn_timer <= 0 + + # Find first inactive collectible slot + def find_inactive_idx(collectibles_in): + inactive_mask = ~collectibles_in.active + first_inactive = jnp.argmax(inactive_mask.astype(jnp.int32)) + has_inactive = jnp.any(inactive_mask) + return jax.lax.cond( + has_inactive, + lambda _: first_inactive, + lambda _: jnp.array(0, dtype=jnp.int32), + operand=None, + ), has_inactive + + spawn_idx, has_inactive_slot = find_inactive_idx(state.collectibles) + + # Generate random spawn position using fold_in for deterministic randomness + base_key = jax.random.PRNGKey(0) + key_for_spawn = jax.random.fold_in(base_key, state.step_counter) + key1, key2, key3, key4, key5 = jax.random.split(key_for_spawn, 5) + y_spawn = jax.random.uniform(key1, minval=-900.0, maxval=-100.0) + road_spawn = jnp.array(jax.random.randint(key2, shape=(), minval=0, maxval=2), dtype=jnp.int32) + color_spawn = jnp.array(jax.random.randint(key3, shape=(), minval=0, maxval=len(self.consts.COLLECTIBLE_COLORS)), dtype=jnp.int32) + + # Randomly select collectible type based on spawn probabilities + # Convert probabilities (%) to cumulative distribution for sampling + rand_type = jax.random.uniform(key4, minval=0.0, maxval=100.0) + + # Use cumulative probabilities: cherry [0-40], balloon [40-60], lollypop [60-80], ice_cream [80-100] + def select_type(rand_val): + # Returns 0=cherry, 1=balloon, 2=lollypop, 3=ice_cream + type_id = jnp.where( + rand_val < self.consts.COLLECTIBLE_SPAWN_PROBABILITIES[0], + jnp.int32(self.consts.COLLECTIBLE_TYPE_CHERRY), + jnp.where( + rand_val < self.consts.COLLECTIBLE_SPAWN_PROBABILITIES[1], + jnp.int32(self.consts.COLLECTIBLE_TYPE_BALLOON), + jnp.where( + rand_val < self.consts.COLLECTIBLE_SPAWN_PROBABILITIES[2], + jnp.int32(self.consts.COLLECTIBLE_TYPE_LOLLYPOP), + jnp.int32(self.consts.COLLECTIBLE_TYPE_ICE_CREAM) + ) + ) + ) + return type_id + + type_id_spawn = select_type(rand_type) + + # Calculate X position on road + def get_road_segment(y): + segments = jnp.sum(self.consts.TRACK_CORNERS_Y > y) + return jnp.clip(segments - 1, 0, len(self.consts.TRACK_CORNERS_Y) - 2) + + segment_spawn = get_road_segment(y_spawn) + x_spawn = jax.lax.cond( + road_spawn == 0, + lambda _: self._get_x_on_road(y_spawn, segment_spawn, self.consts.FIRST_TRACK_CORNERS_X), + lambda _: self._get_x_on_road(y_spawn, segment_spawn, self.consts.SECOND_TRACK_CORNERS_X), + operand=None, + ) + + # Create mask for which collectibles to update + update_mask = (jnp.arange(self.consts.MAX_COLLECTIBLES) == spawn_idx) & should_spawn & has_inactive_slot + + # Update collectibles with proper masking + updated_collectibles = Collectible( + y=jnp.where(update_mask, y_spawn, state.collectibles.y), + x=jnp.where(update_mask, x_spawn, state.collectibles.x), + road=jnp.where(update_mask, road_spawn, state.collectibles.road), + color_idx=jnp.where(update_mask, color_spawn, state.collectibles.color_idx), + type_id=jnp.where(update_mask, type_id_spawn, state.collectibles.type_id), + active=jnp.where(update_mask, True, state.collectibles.active), + ) + + # Despawn logic - remove collectibles too far from player + def check_despawn(idx): + c_y = updated_collectibles.y[idx] + c_active = updated_collectibles.active[idx] + distance = jnp.abs(new_player_y - c_y) + too_far = distance > self.consts.COLLECTIBLE_DESPAWN_DISTANCE + should_despawn = jnp.logical_and(c_active, too_far) + return should_despawn + + despawn_mask = jax.vmap(check_despawn)(jnp.arange(self.consts.MAX_COLLECTIBLES)) + new_active = jnp.logical_and(updated_collectibles.active, ~despawn_mask) + + # Collision detection + def check_collision(idx): + c_y = updated_collectibles.y[idx] + c_x = updated_collectibles.x[idx] + c_road = updated_collectibles.road[idx] + c_active = updated_collectibles.active[idx] + + y_distance = jnp.abs(new_player_y - c_y) + x_distance = jnp.abs(player_x - c_x) + same_road = jnp.logical_or( + jnp.logical_and(current_road == 0, c_road == 0), + jnp.logical_and(current_road == 1, c_road == 1), + ) + + collision = jnp.logical_and( + jnp.logical_and(y_distance < 5, x_distance < 5), + jnp.logical_and(same_road, c_active) + ) + return collision + + collections = jax.vmap(check_collision)(jnp.arange(self.consts.MAX_COLLECTIBLES)) + + # Deactivate collected items + new_active = jnp.logical_and(new_active, ~collections) + + # Update score - use type_id to look up score value + def get_collection_score(idx): + is_collected = collections[idx] + type_id = updated_collectibles.type_id[idx] + # Look up score based on type_id using array indexing + score = self.consts.COLLECTIBLE_SCORES[type_id] + return jnp.where(is_collected, score, 0) + + score_array = jax.vmap(get_collection_score)(jnp.arange(self.consts.MAX_COLLECTIBLES)) + score_delta = jnp.sum(score_array) + + updated_collectibles = Collectible( + y=updated_collectibles.y, + x=updated_collectibles.x, + road=updated_collectibles.road, + color_idx=updated_collectibles.color_idx, + type_id=updated_collectibles.type_id, + active=new_active, + ) + + return updated_collectibles, score_delta, new_collectible_timer + def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: up = jnp.logical_or(action == Action.UP, action == Action.UPFIRE) down = jnp.logical_or(action == Action.DOWN, action == Action.DOWNFIRE) @@ -302,8 +597,9 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: jax.debug.print("Player X: {}, Player Y: {}, car_direction_x: {}", player_x, player_y, car_direction_x) - - #jax.debug.print("Player X: {}, Player Y: {}, on road: {}, jumping: {}, speed: {}, road index A: {}, road index B: {}, current road: {}", player_x, player_y, is_on_road, is_jumping, player_speed, road_index_A, road_index_B, current_road) + # Calculate new player y position after wrapping + new_player_y = -((player_y * -1) % 1036) + return UpNDownState( score=state.score, difficulty=state.difficulty, @@ -313,7 +609,7 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: player_car=Car( position=EntityPosition( x=player_x, - y=-((player_y * -1) % 1036), + y=new_player_y, width=state.player_car.position.width, height=state.player_car.position.height, ), @@ -325,10 +621,105 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: type=state.player_car.type, ), step_counter=state.step_counter + 1, + flags=state.flags, + flags_collected_mask=state.flags_collected_mask, + collectibles=state.collectibles, + collectible_spawn_timer=state.collectible_spawn_timer, + ) + + def _flag_step_main(self, state: UpNDownState) -> UpNDownState: + """Update flag collection state and score.""" + new_player_y = state.player_car.position.y + player_x = state.player_car.position.x + current_road = state.player_car.current_road + + new_flags, flag_score, new_flags_collected_mask = self._flag_step( + state, new_player_y, player_x, current_road + ) + + return UpNDownState( + score=state.score + flag_score, + difficulty=state.difficulty, + jump_cooldown=state.jump_cooldown, + is_jumping=state.is_jumping, + is_on_road=state.is_on_road, + player_car=state.player_car, + step_counter=state.step_counter, + flags=new_flags, + flags_collected_mask=new_flags_collected_mask, + collectibles=state.collectibles, + collectible_spawn_timer=state.collectible_spawn_timer, + ) + + def _collectible_step_main(self, state: UpNDownState) -> UpNDownState: + """Update collectible spawning, despawning, and collection.""" + new_player_y = state.player_car.position.y + player_x = state.player_car.position.x + current_road = state.player_car.current_road + + updated_collectibles, collectible_score, new_collectible_timer = self._collectible_step( + state, new_player_y, player_x, current_road + ) + + return UpNDownState( + score=state.score + collectible_score, + difficulty=state.difficulty, + jump_cooldown=state.jump_cooldown, + is_jumping=state.is_jumping, + is_on_road=state.is_on_road, + player_car=state.player_car, + step_counter=state.step_counter, + flags=state.flags, + flags_collected_mask=state.flags_collected_mask, + collectibles=updated_collectibles, + collectible_spawn_timer=new_collectible_timer, ) def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: + # Initialize flags at random positions along the track + # Use key for randomness if provided, otherwise use default positions + if key is None: + key = jax.random.PRNGKey(42) + + # Evenly spread flags along the track with small jitter + key, subkey = jax.random.split(key) + base_y = jnp.linspace(-900.0, -100.0, self.consts.NUM_FLAGS) + jitter = jax.random.uniform(subkey, shape=(self.consts.NUM_FLAGS,), minval=-40.0, maxval=40.0) + flag_y_offsets = base_y + jitter + + # Alternate roads 0/1 for variety + flag_roads = jnp.arange(self.consts.NUM_FLAGS) % 2 + + # Calculate which road segment each flag is on based on Y position + def get_road_segment(y): + # Find the segment where TRACK_CORNERS_Y[i] > y >= TRACK_CORNERS_Y[i+1] + segments = jnp.sum(self.consts.TRACK_CORNERS_Y > y) + return jnp.clip(segments - 1, 0, len(self.consts.TRACK_CORNERS_Y) - 2) + + flag_segments = jax.vmap(get_road_segment)(flag_y_offsets) + + # Each flag color index corresponds to its position (0-7) + flag_color_indices = jnp.arange(self.consts.NUM_FLAGS) + + flags = Flag( + y=flag_y_offsets, + road=flag_roads, + road_segment=flag_segments, + color_idx=flag_color_indices, + collected=jnp.zeros(self.consts.NUM_FLAGS, dtype=jnp.bool_), + ) + + # Initialize collectibles as all inactive (will spawn dynamically with mixed types) + collectibles = Collectible( + y=jnp.zeros(self.consts.MAX_COLLECTIBLES), + x=jnp.zeros(self.consts.MAX_COLLECTIBLES), + road=jnp.zeros(self.consts.MAX_COLLECTIBLES, dtype=jnp.int32), + color_idx=jnp.zeros(self.consts.MAX_COLLECTIBLES, dtype=jnp.int32), + type_id=jnp.zeros(self.consts.MAX_COLLECTIBLES, dtype=jnp.int32), + active=jnp.zeros(self.consts.MAX_COLLECTIBLES, dtype=jnp.bool_), + ) + state = UpNDownState( score=0, difficulty=self.consts.DIFFICULTIES[0], @@ -350,6 +741,10 @@ def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: type=0, ), step_counter=jnp.array(0), + flags=flags, + flags_collected_mask=jnp.zeros(self.consts.NUM_FLAGS, dtype=jnp.bool_), + collectibles=collectibles, + collectible_spawn_timer=jnp.array(0, dtype=jnp.int32), ) initial_obs = self._get_observation(state) return initial_obs, state @@ -358,6 +753,8 @@ def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: def step(self, state: UpNDownState, action: chex.Array) -> Tuple[UpNDownObservation, UpNDownState, float, bool, UpNDownInfo]: previous_state = state state = self._player_step(state, action) + state = self._flag_step_main(state) + state = self._collectible_step_main(state) done = self._get_done(state) env_reward = self._get_reward(previous_state, state) @@ -439,9 +836,10 @@ def __init__(self, consts: UpNDownConstants = None): top_block = self._createBackgroundSprite((25, self.config.game_dimensions[1])) bottom_block = self._createBackgroundSprite((16, self.config.game_dimensions[1])) temp_pointer = self._createBackgroundSprite((1, 1)) + blackout_square = self._createBackgroundSprite(self.consts.FLAG_BLACKOUT_SIZE) # 2. Update asset config to include both walls - asset_config, road_files = self._get_asset_config(background, top_block, bottom_block, temp_pointer) + asset_config, road_files = self._get_asset_config(background, top_block, bottom_block, temp_pointer, blackout_square) sprite_path = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/" # 3. Make a single call to the setup function @@ -459,6 +857,28 @@ def __init__(self, consts: UpNDownConstants = None): repeats = max(1, math.ceil(self.view_height / road_cycle) + 2) self._road_tile_offsets = jnp.arange(-repeats, repeats + 1, dtype=jnp.int32) * jnp.int32(self.complete_road_size) self._num_road_tiles = int(self._road_tile_offsets.shape[0]) + + # Precompute flag mask data for recoloring without special-casing pink + self.flag_base_mask = self.SHAPE_MASKS["pink_flag"] + self.flag_solid_mask = self.flag_base_mask != self.jr.TRANSPARENT_ID + self.flag_palette_ids = self._compute_flag_palette_ids() + + # Precompute collectible mask data for recoloring (unified for all types: cherry, balloon, lollypop, ice cream) + self.cherry_base_mask = self.SHAPE_MASKS["cherry"] + self.cherry_solid_mask = self.cherry_base_mask != self.jr.TRANSPARENT_ID + self.cherry_palette_ids = self._compute_flag_palette_ids() + + self.balloon_base_mask = self.SHAPE_MASKS["balloon"] + self.balloon_solid_mask = self.balloon_base_mask != self.jr.TRANSPARENT_ID + self.balloon_palette_ids = self._compute_flag_palette_ids() + + self.lollypop_base_mask = self.SHAPE_MASKS["lollypop"] + self.lollypop_solid_mask = self.lollypop_base_mask != self.jr.TRANSPARENT_ID + self.lollypop_palette_ids = self._compute_flag_palette_ids() + + self.ice_cream_base_mask = self.SHAPE_MASKS["ice_cream"] + self.ice_cream_solid_mask = self.ice_cream_base_mask != self.jr.TRANSPARENT_ID + self.ice_cream_palette_ids = self._compute_flag_palette_ids() def _createBackgroundSprite(self, dimensions: Tuple[int, int]) -> jnp.ndarray: """Creates a procedural background sprite for the game.""" @@ -479,8 +899,21 @@ def _get_road_sprite_sizes(self, road_files: list[str]) -> list: complete_size = int(sum(sizes)) jax.debug.print("Complete road size: {}", complete_size) return sizes, complete_size + + def _find_palette_id(self, rgba: jnp.ndarray) -> int: + """Return palette index for an RGBA color, falling back to first entry if missing.""" + color_rgb = rgba[:3] + palette_rgb = self.PALETTE[:, :3] + matches = jnp.all(palette_rgb == color_rgb, axis=1) + found = jnp.argmax(matches) + # If no match, fallback to 0 (background) to avoid crashes + return int(found) + + def _compute_flag_palette_ids(self) -> jnp.ndarray: + """Precompute palette indices for each flag color without special-casing pink.""" + return jnp.array([self._find_palette_id(color) for color in self.consts.FLAG_COLORS], dtype=jnp.int32) - def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.ndarray, bottomBlockSprite: jnp.ndarray, tempPointer: jnp.ndarray) -> tuple[list, list[str]]: + def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.ndarray, bottomBlockSprite: jnp.ndarray, tempPointer: jnp.ndarray, blackoutSquare: jnp.ndarray) -> tuple[list, list[str]]: """Returns the asset manifest and ordered road files.""" road_dir = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/roads" road_files = sorted( @@ -494,7 +927,16 @@ def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.n {'name': 'player', 'type': 'single', 'file': 'player_car.npy'}, {'name': 'wall_top', 'type': 'procedural', 'data': topBlockSprite}, {'name': 'wall_bottom', 'type': 'procedural', 'data': bottomBlockSprite}, + {'name': 'all_flags_top', 'type': 'single', 'file': 'all_flags_top.npy'}, + {'name': 'all_lives_bottom', 'type': 'single', 'file': 'all_lives_bottom.npy'}, + {'name': 'pink_flag', 'type': 'single', 'file': 'pink_flag.npy'}, + {'name': 'flag_pole', 'type': 'single', 'file': 'flag_pole.npy'}, + {'name': 'cherry', 'type': 'single', 'file': 'cherry.npy'}, + {'name': 'balloon', 'type': 'single', 'file': 'balloon.npy'}, + {'name': 'lollypop', 'type': 'single', 'file': 'lollypop.npy'}, + {'name': 'ice_cream', 'type': 'single', 'file': 'ice_cream_cone.npy'}, {'name': 'tempPointer', 'type': 'procedural', 'data': tempPointer}, + {'name': 'blackout_square', 'type': 'procedural', 'data': blackoutSquare}, ], roads @partial(jax.jit, static_argnums=(0,)) @@ -556,7 +998,165 @@ def combine(i, acc): wall_bottom_mask = self.SHAPE_MASKS["wall_bottom"] raster = self.jr.render_at(raster, 0, 210 - wall_bottom_mask.shape[0], wall_bottom_mask) + all_flags_top_mask = self.SHAPE_MASKS["all_flags_top"] + raster = self.jr.render_at(raster, 10, 20, all_flags_top_mask) + + # Render flags on the road + flag_pole_mask = self.SHAPE_MASKS["flag_pole"] + + def render_flag(carry, flag_idx): + raster = carry + flag_y = state.flags.y[flag_idx] + flag_road = state.flags.road[flag_idx] + flag_segment = state.flags.road_segment[flag_idx] + flag_collected = state.flags.collected[flag_idx] + flag_color_idx = state.flags.color_idx[flag_idx] + + # Calculate flag X position on its road + flag_x = jax.lax.cond( + flag_road == 0, + lambda _: self._get_flag_x_on_road(flag_y, flag_segment, self.consts.FIRST_TRACK_CORNERS_X), + lambda _: self._get_flag_x_on_road(flag_y, flag_segment, self.consts.SECOND_TRACK_CORNERS_X), + operand=None, + ) + + # Calculate screen Y position relative to player + # The player is always rendered at Y=105, so flags scroll based on player position + screen_y = 105 + (flag_y - state.player_car.position.y) + + # Check if flag is visible on screen and not collected + is_visible = jnp.logical_and( + jnp.logical_and(screen_y > 25, screen_y < 195), + ~flag_collected + ) + + # Colorize the base flag mask + color_id = self.flag_palette_ids[flag_color_idx] + colored_flag_mask = jnp.where( + self.flag_solid_mask, + color_id, + self.flag_base_mask, + ) + + # Render flag if visible + raster = jax.lax.cond( + is_visible, + lambda r: self.jr.render_at( + self.jr.render_at(r, flag_x.astype(jnp.int32), screen_y.astype(jnp.int32), colored_flag_mask), + (flag_x + 5).astype(jnp.int32), screen_y.astype(jnp.int32), flag_pole_mask + ), + lambda r: r, + operand=raster, + ) + return raster, None + + raster, _ = jax.lax.scan(render_flag, raster, jnp.arange(self.consts.NUM_FLAGS)) + + # Black out collected flags at the top + blackout_mask = self.SHAPE_MASKS["blackout_square"] + + def render_blackout(carry, flag_idx): + raster = carry + flag_collected = state.flags_collected_mask[flag_idx] + blackout_x = self.consts.FLAG_TOP_X_POSITIONS[flag_idx] + blackout_y = self.consts.FLAG_TOP_Y + + raster = jax.lax.cond( + flag_collected, + lambda r: self.jr.render_at(r, blackout_x, blackout_y, blackout_mask), + lambda r: r, + operand=raster, + ) + return raster, None + + raster, _ = jax.lax.scan(render_blackout, raster, jnp.arange(self.consts.NUM_FLAGS)) + + # Render collectibles (unified for all types: cherry, balloon, lollypop, ice cream) + def render_collectible(carry, collectible_idx): + raster = carry + collectible_y = state.collectibles.y[collectible_idx] + collectible_x = state.collectibles.x[collectible_idx] + collectible_active = state.collectibles.active[collectible_idx] + collectible_color_idx = state.collectibles.color_idx[collectible_idx] + collectible_type_id = state.collectibles.type_id[collectible_idx] + + # Calculate screen Y position relative to player + screen_y = 105 + (collectible_y - state.player_car.position.y) + + # Check if collectible is visible on screen and active + is_visible = jnp.logical_and( + jnp.logical_and(screen_y > 25, screen_y < 195), + collectible_active + ) + + # Select sprite based on type_id + # type_id: 0=cherry, 1=balloon, 2=lollypop, 3=ice_cream + def get_sprite_and_mask(type_id): + cherry_result = (self.cherry_base_mask, self.cherry_solid_mask, self.cherry_palette_ids) + balloon_result = (self.balloon_base_mask, self.balloon_solid_mask, self.balloon_palette_ids) + lollypop_result = (self.lollypop_base_mask, self.lollypop_solid_mask, self.lollypop_palette_ids) + ice_cream_result = (self.ice_cream_base_mask, self.ice_cream_solid_mask, self.ice_cream_palette_ids) + + # Use conditional branching to select sprite + result = jax.lax.cond( + type_id == self.consts.COLLECTIBLE_TYPE_CHERRY, + lambda _: cherry_result, + lambda _: jax.lax.cond( + type_id == self.consts.COLLECTIBLE_TYPE_BALLOON, + lambda _: balloon_result, + lambda _: jax.lax.cond( + type_id == self.consts.COLLECTIBLE_TYPE_LOLLYPOP, + lambda _: lollypop_result, + lambda _: ice_cream_result, + operand=None, + ), + operand=None, + ), + operand=None, + ) + return result + + base_mask, solid_mask, palette_ids = get_sprite_and_mask(collectible_type_id) + + # Only colorize inner pixels, keep black edges (palette ID 0 is black) + color_id = palette_ids[collectible_color_idx] + colored_mask = jnp.where( + (base_mask != self.jr.TRANSPARENT_ID) & (base_mask != 0), + color_id, + base_mask, + ) + + # Render collectible if visible + raster = jax.lax.cond( + is_visible, + lambda r: self.jr.render_at(r, collectible_x.astype(jnp.int32), screen_y.astype(jnp.int32), colored_mask), + lambda r: r, + operand=raster, + ) + return raster, None + + raster, _ = jax.lax.scan(render_collectible, raster, jnp.arange(self.consts.MAX_COLLECTIBLES)) + + all_lives_bottom_mask = self.SHAPE_MASKS["all_lives_bottom"] + raster = self.jr.render_at(raster, 10, 195, all_lives_bottom_mask) + wall_bottom_mask = self.SHAPE_MASKS["tempPointer"] raster = self.jr.render_at(raster, 140, 25, wall_bottom_mask) - return self.jr.render_from_palette(raster, self.PALETTE) \ No newline at end of file + return self.jr.render_from_palette(raster, self.PALETTE) + + def _get_flag_x_on_road(self, y: chex.Array, road_segment: chex.Array, track_corners_x: chex.Array) -> chex.Array: + """Calculate the X position on a road given a Y coordinate and road segment.""" + y1 = self.consts.TRACK_CORNERS_Y[road_segment] + y2 = self.consts.TRACK_CORNERS_Y[road_segment + 1] + x1 = track_corners_x[road_segment] + x2 = track_corners_x[road_segment + 1] + + # Linear interpolation: x = x1 + (y - y1) * (x2 - x1) / (y2 - y1) + t = jax.lax.cond( + y2 != y1, + lambda _: (y - y1) / (y2 - y1), + lambda _: 0.0, + operand=None, + ) + return x1 + t * (x2 - x1) \ No newline at end of file From 301b96ec738c9ae9366a30c45b3555c9230cba9b Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Wed, 17 Dec 2025 16:43:57 +0100 Subject: [PATCH 52/76] add score and display and passive point gain --- src/jaxatari/games/jax_upndown.py | 72 +++++++++++++++++++++++++++++++ 1 file changed, 72 insertions(+) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index e469364bf..a1646fa43 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -114,6 +114,8 @@ class UpNDownState(NamedTuple): is_on_road: chex.Array player_car: Car step_counter: chex.Array + round_started: chex.Array + movement_steps: chex.Array # Flag state - tracks all 8 flags flags: Flag # Contains arrays of size NUM_FLAGS for each field flags_collected_mask: chex.Array # Boolean mask of which flag colors have been collected (size NUM_FLAGS) @@ -621,6 +623,13 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: type=state.player_car.type, ), step_counter=state.step_counter + 1, + round_started=jnp.logical_or(state.round_started, player_speed != 0), + movement_steps=jax.lax.cond( + jnp.logical_or(state.round_started, player_speed != 0), + lambda s: state.movement_steps + 1, + lambda s: state.movement_steps, + operand=None, + ), flags=state.flags, flags_collected_mask=state.flags_collected_mask, collectibles=state.collectibles, @@ -645,6 +654,8 @@ def _flag_step_main(self, state: UpNDownState) -> UpNDownState: is_on_road=state.is_on_road, player_car=state.player_car, step_counter=state.step_counter, + round_started=state.round_started, + movement_steps=state.movement_steps, flags=new_flags, flags_collected_mask=new_flags_collected_mask, collectibles=state.collectibles, @@ -669,12 +680,39 @@ def _collectible_step_main(self, state: UpNDownState) -> UpNDownState: is_on_road=state.is_on_road, player_car=state.player_car, step_counter=state.step_counter, + round_started=state.round_started, + movement_steps=state.movement_steps, flags=state.flags, flags_collected_mask=state.flags_collected_mask, collectibles=updated_collectibles, collectible_spawn_timer=new_collectible_timer, ) + def _passive_score_step_main(self, state: UpNDownState) -> UpNDownState: + """Award passive score every 60 steps after the player has started moving.""" + bonus = jax.lax.cond( + jnp.logical_and(state.round_started, state.movement_steps % 60 == 0), + lambda _: jnp.int32(10), + lambda _: jnp.int32(0), + operand=None, + ) + + return UpNDownState( + score=state.score + bonus, + difficulty=state.difficulty, + jump_cooldown=state.jump_cooldown, + is_jumping=state.is_jumping, + is_on_road=state.is_on_road, + player_car=state.player_car, + step_counter=state.step_counter, + round_started=state.round_started, + movement_steps=state.movement_steps, + flags=state.flags, + flags_collected_mask=state.flags_collected_mask, + collectibles=state.collectibles, + collectible_spawn_timer=state.collectible_spawn_timer, + ) + def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: # Initialize flags at random positions along the track @@ -741,6 +779,8 @@ def get_road_segment(y): type=0, ), step_counter=jnp.array(0), + round_started=jnp.array(False), + movement_steps=jnp.array(0), flags=flags, flags_collected_mask=jnp.zeros(self.consts.NUM_FLAGS, dtype=jnp.bool_), collectibles=collectibles, @@ -753,6 +793,7 @@ def get_road_segment(y): def step(self, state: UpNDownState, action: chex.Array) -> Tuple[UpNDownObservation, UpNDownState, float, bool, UpNDownInfo]: previous_state = state state = self._player_step(state, action) + state = self._passive_score_step_main(state) state = self._flag_step_main(state) state = self._collectible_step_main(state) @@ -880,6 +921,13 @@ def __init__(self, consts: UpNDownConstants = None): self.ice_cream_solid_mask = self.ice_cream_base_mask != self.jr.TRANSPARENT_ID self.ice_cream_palette_ids = self._compute_flag_palette_ids() + # Score rendering helpers + self.score_digit_masks = self.SHAPE_MASKS["score_digits"] + self.score_max_digits = 6 + self.score_digit_spacing = int(self.score_digit_masks.shape[2]) + 1 + self.score_render_y = 6 + self.score_center_x = self.config.game_dimensions[1] // 2 - self.config.game_dimensions[1] // 4 + def _createBackgroundSprite(self, dimensions: Tuple[int, int]) -> jnp.ndarray: """Creates a procedural background sprite for the game.""" height, width = dimensions @@ -929,6 +977,7 @@ def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.n {'name': 'wall_bottom', 'type': 'procedural', 'data': bottomBlockSprite}, {'name': 'all_flags_top', 'type': 'single', 'file': 'all_flags_top.npy'}, {'name': 'all_lives_bottom', 'type': 'single', 'file': 'all_lives_bottom.npy'}, + {'name': 'score_digits', 'type': 'digits', 'pattern': 'score/score_{}.npy'}, {'name': 'pink_flag', 'type': 'single', 'file': 'pink_flag.npy'}, {'name': 'flag_pole', 'type': 'single', 'file': 'flag_pole.npy'}, {'name': 'cherry', 'type': 'single', 'file': 'cherry.npy'}, @@ -1001,6 +1050,29 @@ def combine(i, acc): all_flags_top_mask = self.SHAPE_MASKS["all_flags_top"] raster = self.jr.render_at(raster, 10, 20, all_flags_top_mask) + # Render score centered at the top using dedicated score digit sprites + score_digits = self.jr.int_to_digits(state.score, max_digits=self.score_max_digits) + non_zero_mask = score_digits != 0 + has_non_zero = jnp.any(non_zero_mask) + first_non_zero = jnp.argmax(non_zero_mask) + start_index = jax.lax.select(has_non_zero, first_non_zero, self.score_max_digits - 1) + num_to_render = jax.lax.select(has_non_zero, self.score_max_digits - start_index, 1) + + total_width = num_to_render * self.score_digit_spacing + score_x = self.score_center_x - (total_width // 2) + + raster = self.jr.render_label_selective( + raster, + jnp.int32(score_x), + self.score_render_y, + score_digits, + self.score_digit_masks, + start_index, + num_to_render, + spacing=self.score_digit_spacing, + max_digits_to_render=self.score_max_digits, + ) + # Render flags on the road flag_pole_mask = self.SHAPE_MASKS["flag_pole"] From 5c792c9f3533060d1dc4b81961bf17c9d14a3349 Mon Sep 17 00:00:00 2001 From: shaik05 Date: Sun, 21 Dec 2025 10:07:21 +0100 Subject: [PATCH 53/76] a --- src/jaxatari/games/jax_upndown.py | 166 ++++++++++++++++++++++++------ 1 file changed, 137 insertions(+), 29 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index a1646fa43..5a9987a2c 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -17,7 +17,13 @@ class UpNDownConstants(NamedTuple): DIFFICULTIES: chex.Array = jnp.array([0, 1, 2, 3, 4, 5]) ACTION_REPEAT_PROBS: float = 0.25 MAX_SPEED: int = 4 + INITIAL_LIVES: int = 3 + RESPAWN_Y: int = 0 + RESPAWN_X: int = 30 + RESPAWN_DELAY_FRAMES: int = 60 + WATER_DEATH_PENALTY: int = 0 JUMP_FRAMES: int = 10 + ALL_FLAGS_BONUS: int = 1000 LANDING_ZONE: int = 15 FIRST_ROAD_LENGTH: int = 4 SECOND_ROAD_LENGTH: int = 4 @@ -113,6 +119,9 @@ class UpNDownState(NamedTuple): is_jumping: chex.Array is_on_road: chex.Array player_car: Car + lives: chex.Array + is_dead: chex.Array + respawn_timer: chex.Array step_counter: chex.Array round_started: chex.Array movement_steps: chex.Array @@ -421,11 +430,63 @@ def get_collection_score(idx): ) return updated_collectibles, score_delta, new_collectible_timer + def _death_step(self, state: UpNDownState) -> UpNDownState: + # Player on water road (index 2 assumed water) + died = jnp.logical_and( + state.player_car.current_road == 2, + ~state.is_dead, + ) + + lives = jax.lax.cond( + died, + lambda _: state.lives - 1, + lambda _: state.lives, + operand=None, + ) + lives = jax.lax.cond( + died, + lambda _: state.lives - 1, + lambda _: state.lives, + operand=None, + ) + respawn_timer = jax.lax.cond( + died, + lambda _: jnp.array(self.consts.RESPAWN_DELAY_FRAMES), + lambda _: jnp.maximum(state.respawn_timer - 1, 0), + operand=None, + ) + is_dead = jnp.logical_and( + jnp.logical_or(state.is_dead, died), + respawn_timer > 0) + + player_car = jax.lax.cond( + jnp.logical_and(state.is_dead, respawn_timer == 0), + lambda _: state.player_car._replace( + position=state.player_car.position._replace( + x=jnp.array(self.consts.RESPAWN_X, dtype=jnp.float32), + y=jnp.array(self.consts.RESPAWN_Y, dtype=jnp.float32), + ), + speed=0, + current_road=0, + ), + lambda _: state.player_car, + operand=None, + ) + return state._replace( + lives=lives, + is_dead=is_dead, + respawn_timer=respawn_timer, + player_car=player_car, + ) + def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: up = jnp.logical_or(action == Action.UP, action == Action.UPFIRE) down = jnp.logical_or(action == Action.DOWN, action == Action.DOWNFIRE) jump = jnp.logical_or(action == Action.FIRE, jnp.logical_or(action == Action.UPFIRE, action == Action.DOWNFIRE)) + lives=jnp.array(self.consts.INITIAL_LIVES, dtype=jnp.int32), + is_dead=jnp.array(False), + respawn_timer=jnp.array(0, dtype=jnp.int32), @@ -622,6 +683,9 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: road_index_B=road_index_B, type=state.player_car.type, ), + lives=state.lives, + is_dead=state.is_dead, + respawn_timer=state.respawn_timer, step_counter=state.step_counter + 1, round_started=jnp.logical_or(state.round_started, player_speed != 0), movement_steps=jax.lax.cond( @@ -653,6 +717,9 @@ def _flag_step_main(self, state: UpNDownState) -> UpNDownState: is_jumping=state.is_jumping, is_on_road=state.is_on_road, player_car=state.player_car, + lives=state.lives, + is_dead=state.is_dead, + respawn_timer=state.respawn_timer, step_counter=state.step_counter, round_started=state.round_started, movement_steps=state.movement_steps, @@ -661,6 +728,18 @@ def _flag_step_main(self, state: UpNDownState) -> UpNDownState: collectibles=state.collectibles, collectible_spawn_timer=state.collectible_spawn_timer, ) + def _completion_bonus_step(self, state: UpNDownState) -> UpNDownState: + all_flags_collected = jnp.all(state.flags_collected_mask) + + bonus = jax.lax.cond( + all_flags_collected, + lambda _: self.consts.ALL_FLAGS_BONUS, + lambda _: 0, + operand=None, + ) + return state._replace(score=state.score + bonus,lives=state.lives, + is_dead=state.is_dead, + respawn_timer=state.respawn_timer,) def _collectible_step_main(self, state: UpNDownState) -> UpNDownState: """Update collectible spawning, despawning, and collection.""" @@ -679,6 +758,9 @@ def _collectible_step_main(self, state: UpNDownState) -> UpNDownState: is_jumping=state.is_jumping, is_on_road=state.is_on_road, player_car=state.player_car, + lives=state.lives, + is_dead=state.is_dead, + respawn_timer=state.respawn_timer, step_counter=state.step_counter, round_started=state.round_started, movement_steps=state.movement_steps, @@ -704,6 +786,9 @@ def _passive_score_step_main(self, state: UpNDownState) -> UpNDownState: is_jumping=state.is_jumping, is_on_road=state.is_on_road, player_car=state.player_car, + lives=state.lives, + is_dead=state.is_dead, + respawn_timer=state.respawn_timer, step_counter=state.step_counter, round_started=state.round_started, movement_steps=state.movement_steps, @@ -729,6 +814,9 @@ def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: # Alternate roads 0/1 for variety flag_roads = jnp.arange(self.consts.NUM_FLAGS) % 2 + + + # Calculate which road segment each flag is on based on Y position def get_road_segment(y): # Find the segment where TRACK_CORNERS_Y[i] > y >= TRACK_CORNERS_Y[i+1] @@ -757,35 +845,45 @@ def get_road_segment(y): type_id=jnp.zeros(self.consts.MAX_COLLECTIBLES, dtype=jnp.int32), active=jnp.zeros(self.consts.MAX_COLLECTIBLES, dtype=jnp.bool_), ) - + player_car = Car( + position=EntityPosition( + x=jnp.array(30, dtype=jnp.int32), + y=jnp.array(0, dtype=jnp.int32), + width=jnp.array(self.consts.PLAYER_SIZE[0], dtype=jnp.int32), + height=jnp.array(self.consts.PLAYER_SIZE[1], dtype=jnp.int32), + ), + speed=jnp.array(0, dtype=jnp.int32), + direction_x=jnp.array(0, dtype=jnp.int32), + current_road=jnp.array(0, dtype=jnp.int32), + road_index_A=jnp.array(0, dtype=jnp.int32), + road_index_B=jnp.array(0, dtype=jnp.int32), + type=jnp.array(0, dtype=jnp.int32), + ) state = UpNDownState( - score=0, - difficulty=self.consts.DIFFICULTIES[0], - jump_cooldown=0, - is_jumping=False, - is_on_road=True, - player_car=Car( - position=EntityPosition( - x=30, - y= 0, - width=self.consts.PLAYER_SIZE[0], - height=self.consts.PLAYER_SIZE[1], - ), - speed=0, - direction_x=0, - current_road=0, - road_index_A=0, - road_index_B=0, - type=0, - ), - step_counter=jnp.array(0), - round_started=jnp.array(False), - movement_steps=jnp.array(0), - flags=flags, - flags_collected_mask=jnp.zeros(self.consts.NUM_FLAGS, dtype=jnp.bool_), - collectibles=collectibles, - collectible_spawn_timer=jnp.array(0, dtype=jnp.int32), - ) + score=jnp.array(0, dtype=jnp.int32), + difficulty=jnp.array(self.consts.DIFFICULTIES[0], dtype=jnp.int32), + jump_cooldown=jnp.array(0, dtype=jnp.int32), + is_jumping=jnp.array(False), + is_on_road=jnp.array(True), + + player_car=player_car, + + step_counter=jnp.array(0, dtype=jnp.int32), + round_started=jnp.array(False), + movement_steps=jnp.array(0, dtype=jnp.int32), + + flags=flags, + flags_collected_mask=jnp.zeros(self.consts.NUM_FLAGS, dtype=jnp.bool_), + collectibles=collectibles, + collectible_spawn_timer=jnp.array(0, dtype=jnp.int32), + + # -------- NEW REQUIRED FIELDS -------- + lives=jnp.array(self.consts.INITIAL_LIVES, dtype=jnp.int32), + is_dead=jnp.array(False), + respawn_timer=jnp.array(0, dtype=jnp.int32), + ) + + initial_obs = self._get_observation(state) return initial_obs, state @@ -793,9 +891,14 @@ def get_road_segment(y): def step(self, state: UpNDownState, action: chex.Array) -> Tuple[UpNDownObservation, UpNDownState, float, bool, UpNDownInfo]: previous_state = state state = self._player_step(state, action) + state = self._death_step(state) + state = self._passive_score_step_main(state) state = self._flag_step_main(state) + state = self._completion_bonus_step(state) state = self._collectible_step_main(state) + + done = self._get_done(state) env_reward = self._get_reward(previous_state, state) @@ -860,7 +963,12 @@ def _get_reward(self, previous_state: UpNDownState, state: UpNDownState): @partial(jax.jit, static_argnums=(0,)) def _get_done(self, state: UpNDownState) -> bool: - return jnp.logical_not(True) + return jnp.logical_or( + state.lives <= 0, + jnp.all(state.flags_collected_mask), +) + + class UpNDownRenderer(JAXGameRenderer): def __init__(self, consts: UpNDownConstants = None): From 3f018d4969248eb3a6281f4272a25797353e8d1b Mon Sep 17 00:00:00 2001 From: shaik05 Date: Sun, 21 Dec 2025 10:21:32 +0100 Subject: [PATCH 54/76] movement --- src/jaxatari/games/upndown_interface.py | 53 ------------------------- 1 file changed, 53 deletions(-) delete mode 100644 src/jaxatari/games/upndown_interface.py diff --git a/src/jaxatari/games/upndown_interface.py b/src/jaxatari/games/upndown_interface.py deleted file mode 100644 index 68f8c76fe..000000000 --- a/src/jaxatari/games/upndown_interface.py +++ /dev/null @@ -1,53 +0,0 @@ -import jax -import jax.numpy as jnp -import matplotlib.pyplot as plt - -from jaxatari.environment import JAXAtariAction as Action -from upndown import JaxUpNDown, UpNDownConstants # <-- your game file - -def visualize_frame(frame: jnp.ndarray): - """Render an RGB frame using matplotlib.""" - plt.imshow(frame.astype(jnp.uint8)) - plt.axis("off") - plt.show(block=False) - plt.pause(0.05) - plt.clf() - - -def main(): - # Initialize environment - env = JaxUpNDown(UpNDownConstants()) - - # Reset environment - obs, state = env.reset() - print("Initial observation:", obs) - - # Display initial render - frame = env.render(state) - visualize_frame(frame) - - # Create a random key for sampling actions - key = jax.random.PRNGKey(0) - - # Run for 50 steps - for step in range(50): - key, subkey = jax.random.split(key) - # Choose a random action from action space - action = jax.random.choice(subkey, jnp.arange(len(env.action_set))) - - obs, state, reward, done, info = env.step(state, action) - - # Render and display - frame = env.render(state) - visualize_frame(frame) - - print(f"Step {step}: action={env.action_set[int(action)]}, reward={reward}, done={done}") - - if bool(done): - print("Game over — resetting environment.") - obs, state = env.reset() - - plt.close() - -if __name__ == "__main__": - main() From 7053aeb499617d9fea0b4af381d979dc5a941057 Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Thu, 18 Dec 2025 17:34:33 +0100 Subject: [PATCH 55/76] add enemy cars & enemy logic to game --- src/jaxatari/games/jax_upndown.py | 796 +++++++++++++++++++++--------- 1 file changed, 559 insertions(+), 237 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 5a9987a2c..648859a53 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -16,12 +16,24 @@ class UpNDownConstants(NamedTuple): FRAME_SKIP: int = 4 DIFFICULTIES: chex.Array = jnp.array([0, 1, 2, 3, 4, 5]) ACTION_REPEAT_PROBS: float = 0.25 - MAX_SPEED: int = 4 - INITIAL_LIVES: int = 3 - RESPAWN_Y: int = 0 - RESPAWN_X: int = 30 - RESPAWN_DELAY_FRAMES: int = 60 - WATER_DEATH_PENALTY: int = 0 + MAX_SPEED: int = 5 + # Enemy spawning and movement + MAX_ENEMY_CARS: int = 6 + ENEMY_SPAWN_INTERVAL: int = 80 + ENEMY_DESPAWN_DISTANCE: int = 300 + ENEMY_SPEED_MIN: int = 2 + ENEMY_SPEED_MAX: int = 5 + ENEMY_DIRECTION_SWITCH_PROB: float = 0.005 + ENEMY_OFFSCREEN_SPAWN_OFFSET: float = 100.0 + ENEMY_MIN_SPAWN_GAP: float = 40.0 + ENEMY_MAX_AGE: int = 900 + INITIAL_ENEMY_COUNT: int = 3 + INITIAL_ENEMY_BASE_OFFSET: float = 40.0 + INITIAL_ENEMY_GAP: float = 50.0 + ENEMY_TYPE_CAMERO: int = 0 + ENEMY_TYPE_FLAG_CARRIER: int = 1 + ENEMY_TYPE_PICKUP: int = 2 + ENEMY_TYPE_TRUCK: int = 3 JUMP_FRAMES: int = 10 ALL_FLAGS_BONUS: int = 1000 LANDING_ZONE: int = 15 @@ -112,6 +124,19 @@ class Collectible(NamedTuple): type_id: chex.Array # Type of collectible (0=cherry, 1=balloon, 2=lollypop, 3=ice_cream) active: chex.Array # Whether this collectible slot is active (spawned) + +class EnemyCars(NamedTuple): + """Pool of enemy cars that share the same road-following logic as the player.""" + position: EntityPosition # vectorized position fields, size MAX_ENEMY_CARS + speed: chex.Array # signed speed per car + type: chex.Array # type id per car + current_road: chex.Array + road_index_A: chex.Array + road_index_B: chex.Array + direction_x: chex.Array + active: chex.Array + age: chex.Array + class UpNDownState(NamedTuple): score: chex.Array difficulty: chex.Array @@ -131,6 +156,9 @@ class UpNDownState(NamedTuple): # Collectible state - dynamic spawning (mixed types: cherry, balloon, lollypop, ice cream) collectibles: Collectible # Contains arrays of size MAX_COLLECTIBLES for each field collectible_spawn_timer: chex.Array # Counter for collectible spawn timing + # Enemy cars - dynamic spawning and movement + enemy_cars: EnemyCars + enemy_spawn_timer: chex.Array @@ -161,20 +189,39 @@ def __init__(self, consts: UpNDownConstants = None, reward_funcs: list[callable] self.obs_size = 3*4+1+1 @partial(jax.jit, static_argnums=(0,)) - def _getSlopeAndB(self, state: UpNDownState) -> chex.Array: - trackx, tracky, roadIndex = jax.lax.cond( - state.player_car.current_road == 0, - lambda s: (self.consts.FIRST_TRACK_CORNERS_X, self.consts.TRACK_CORNERS_Y, state.player_car.road_index_A), - lambda s: (self.consts.SECOND_TRACK_CORNERS_X, self.consts.TRACK_CORNERS_Y, state.player_car.road_index_B), - operand=None,) + def _get_slope_and_intercept_from_indices(self, current_road: chex.Array, road_index_A: chex.Array, road_index_B: chex.Array) -> Tuple[chex.Array, chex.Array]: + trackx, tracky, road_index = jax.lax.cond( + current_road == 0, + lambda _: (self.consts.FIRST_TRACK_CORNERS_X, self.consts.TRACK_CORNERS_Y, road_index_A), + lambda _: (self.consts.SECOND_TRACK_CORNERS_X, self.consts.TRACK_CORNERS_Y, road_index_B), + operand=None, + ) slope = jax.lax.cond( - trackx[roadIndex+1] - trackx[roadIndex] != 0, - lambda s: (tracky[roadIndex+1] - tracky[roadIndex]) / (trackx[roadIndex+1] - trackx[roadIndex]), - lambda s: 300.0, + trackx[road_index+1] - trackx[road_index] != 0, + lambda _: (tracky[road_index+1] - tracky[road_index]) / (trackx[road_index+1] - trackx[road_index]), + lambda _: 300.0, operand=None, ) - b = tracky[roadIndex] - slope * trackx[roadIndex] + b = tracky[road_index] - slope * trackx[road_index] return slope, b + + @partial(jax.jit, static_argnums=(0,)) + def _getSlopeAndB(self, state: UpNDownState) -> chex.Array: + return self._get_slope_and_intercept_from_indices( + state.player_car.current_road, + state.player_car.road_index_A, + state.player_car.road_index_B, + ) + + @partial(jax.jit, static_argnums=(0,)) + def _is_on_line_for_position(self, position: EntityPosition, slope: chex.Array, b: chex.Array, player_speed: chex.Array, turn: chex.Array) -> chex.Array: + x_step = abs(jnp.subtract(position.y, slope * (position.x) + b)) + y_step = abs(jnp.subtract(position.y - player_speed, slope * position.x + b)) + prefer_y = jnp.less_equal(y_step, x_step) + return jnp.logical_or( + jnp.logical_and(turn == 1, prefer_y), + jnp.logical_and(turn == 2, jnp.logical_not(prefer_y)), + ) @partial(jax.jit, static_argnums=(0,)) def _get_x_on_road(self, y: chex.Array, road_segment: chex.Array, track_corners_x: chex.Array) -> chex.Array: @@ -196,13 +243,7 @@ def _get_x_on_road(self, y: chex.Array, road_segment: chex.Array, track_corners_ @partial(jax.jit, static_argnums=(0,)) def _isOnLine(self, state: UpNDownState, player_speed: chex.Array, turn: chex.Array) -> chex.Array: slope, b = self._getSlopeAndB(state) - x_step = abs(jnp.subtract(state.player_car.position.y, slope * (state.player_car.position.x) + b)) - y_step = abs(jnp.subtract(state.player_car.position.y - player_speed, slope * state.player_car.position.x + b)) - prefer_y = jnp.less_equal(y_step, x_step) - return jnp.logical_or( - jnp.logical_and(turn == 1, prefer_y), - jnp.logical_and(turn == 2, jnp.logical_not(prefer_y)), - ) + return self._is_on_line_for_position(state.player_car.position, slope, b, player_speed, turn) @partial(jax.jit, static_argnums=(0,)) def _landing_in_water(self, state: UpNDownState, new_position_x: chex.Array, new_position_y: chex.Array) -> chex.Array: @@ -214,6 +255,181 @@ def _landing_in_water(self, state: UpNDownState, new_position_x: chex.Array, new between_roads = jnp.logical_and(new_position_x > jnp.minimum(road_A_x, road_B_x), new_position_x < jnp.maximum(road_A_x, road_B_x)) return landing_in_Water, between_roads, road_A_x, road_B_x + @partial(jax.jit, static_argnums=(0,)) + def _landing_in_water_for_indices(self, road_index_A: chex.Array, road_index_B: chex.Array, new_position_x: chex.Array, new_position_y: chex.Array) -> chex.Array: + road_A_x = ((new_position_y - self.consts.TRACK_CORNERS_Y[road_index_A]) / (self.consts.TRACK_CORNERS_Y[road_index_A+1] - self.consts.TRACK_CORNERS_Y[road_index_A])) * (self.consts.FIRST_TRACK_CORNERS_X[road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_X[road_index_A]) + self.consts.FIRST_TRACK_CORNERS_X[road_index_A] + road_B_x = ((new_position_y - self.consts.TRACK_CORNERS_Y[road_index_B]) / (self.consts.TRACK_CORNERS_Y[road_index_B+1] - self.consts.TRACK_CORNERS_Y[road_index_B])) * (self.consts.SECOND_TRACK_CORNERS_X[road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[road_index_B]) + self.consts.SECOND_TRACK_CORNERS_X[road_index_B] + distance_to_road_A = jnp.abs(new_position_x - road_A_x) + distance_to_road_B = jnp.abs(new_position_x - road_B_x) + landing_in_water = jnp.logical_and(distance_to_road_A > self.consts.LANDING_ZONE, distance_to_road_B > self.consts.LANDING_ZONE) + between_roads = jnp.logical_and(new_position_x > jnp.minimum(road_A_x, road_B_x), new_position_x < jnp.maximum(road_A_x, road_B_x)) + return landing_in_water, between_roads, road_A_x, road_B_x + + @partial(jax.jit, static_argnums=(0,)) + def _advance_car_core( + self, + position_x: chex.Array, + position_y: chex.Array, + road_index_A: chex.Array, + road_index_B: chex.Array, + current_road: chex.Array, + speed: chex.Array, + is_jumping: chex.Array, + is_on_road: chex.Array, + step_counter: chex.Array, + width: chex.Array, + height: chex.Array, + car_type: chex.Array, + landing_check: chex.Array, + ) -> Car: + dividers = jnp.array([0, 1, 2, 4, 8, 16]) + abs_speed = jnp.abs(speed) + speed_divider = dividers[abs_speed] + effective_divider = jnp.maximum(1, speed_divider) + period = jnp.maximum(1, 16 // effective_divider) + half_period = jnp.maximum(1, period // 2) + speed_sign = jax.lax.cond(speed != 0, lambda _: jax.lax.abs(speed) / speed, lambda _: jnp.array(0.0), operand=None) + + slope, b = self._get_slope_and_intercept_from_indices(current_road, road_index_A, road_index_B) + + direction_raw = jax.lax.cond( + current_road == 0, + lambda _: self.consts.FIRST_TRACK_CORNERS_X[road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_X[road_index_A], + lambda _: self.consts.SECOND_TRACK_CORNERS_X[road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[road_index_B], + operand=None, + ) + car_direction_x = jax.lax.cond(direction_raw > 0, lambda _: 1, lambda _: -1, operand=None) + + move_y = jnp.logical_and((step_counter % period) == (half_period % period), speed != 0) + move_x = jnp.logical_and((step_counter % period) == 0, speed != 0) + + position = EntityPosition(x=position_x, y=position_y, width=width, height=height) + + new_player_y = jax.lax.cond( + move_y, + lambda _: jax.lax.cond( + is_jumping, + lambda _: position_y + speed_sign * -1, + lambda _: jax.lax.cond( + self._is_on_line_for_position(position, slope, b, speed_sign, 1), + lambda _: position_y + speed_sign * -1, + lambda _: jnp.array(position_y, float), + operand=None, + ), + operand=None, + ), + lambda _: jnp.array(position_y, float), + operand=None, + ) + + new_player_x = jax.lax.cond( + move_x, + lambda _: jax.lax.cond( + is_jumping, + lambda _: position_x + speed_sign * car_direction_x, + lambda _: jax.lax.cond( + self._is_on_line_for_position(position, slope, b, speed_sign, 2), + lambda _: position_x + speed_sign * car_direction_x, + lambda _: jnp.array(position_x, float), + operand=None, + ), + operand=None, + ), + lambda _: jnp.array(position_x, float), + operand=None, + ) + + landing_in_water, between_roads, road_A_x, road_B_x = self._landing_in_water_for_indices(road_index_A, road_index_B, new_player_x, new_player_y) + landing_in_water = jnp.logical_and(landing_check, landing_in_water) + + updated_current_road = jax.lax.cond( + landing_in_water, + lambda _: 2, + lambda _: jax.lax.cond( + is_on_road, + lambda _: current_road, + lambda _: jax.lax.cond( + jnp.abs(new_player_x - road_A_x) < jnp.abs(new_player_x - road_B_x), + lambda _: 0, + lambda _: 1, + operand=None, + ), + operand=None, + ), + operand=None, + ) + + next_road_index_A = jax.lax.cond( + updated_current_road == 2, + lambda _: road_index_A, + lambda _: jax.lax.cond( + self.consts.TRACK_CORNERS_Y[road_index_A] < new_player_y, + lambda _: road_index_A - 1, + lambda _: jax.lax.cond( + len(self.consts.TRACK_CORNERS_Y) == road_index_A + 1, + lambda _: jax.lax.cond( + self.consts.TRACK_CORNERS_Y[0] > new_player_y, + lambda _: 0, + lambda _: road_index_A, + operand=None, + ), + lambda _: jax.lax.cond( + self.consts.TRACK_CORNERS_Y[road_index_A+1] > new_player_y, + lambda _: road_index_A + 1, + lambda _: road_index_A, + operand=None, + ), + operand=None, + ), + operand=None, + ), + operand=None, + ) + + next_road_index_B = jax.lax.cond( + updated_current_road == 2, + lambda _: road_index_B, + lambda _: jax.lax.cond( + self.consts.TRACK_CORNERS_Y[road_index_B] < new_player_y, + lambda _: road_index_B - 1, + lambda _: jax.lax.cond( + len(self.consts.TRACK_CORNERS_Y) == road_index_B + 1, + lambda _: jax.lax.cond( + self.consts.TRACK_CORNERS_Y[0] > new_player_y, + lambda _: 0, + lambda _: road_index_B, + operand=None, + ), + lambda _: jax.lax.cond( + self.consts.TRACK_CORNERS_Y[road_index_B+1] > new_player_y, + lambda _: road_index_B + 1, + lambda _: road_index_B, + operand=None, + ), + operand=None, + ), + operand=None, + ), + operand=None, + ) + + wrapped_y = -((new_player_y * -1) % 1036) + + return Car( + position=EntityPosition( + x=new_player_x, + y=wrapped_y, + width=width, + height=height, + ), + speed=speed, + direction_x=car_direction_x, + current_road=updated_current_road, + road_index_A=next_road_index_A, + road_index_B=next_road_index_B, + type=car_type, + ) + def _flag_step(self, state: UpNDownState, new_player_y: chex.Array, player_x: chex.Array, current_road: chex.Array) -> Tuple[Flag, chex.Array, chex.Array]: """Update flag collection state and score. @@ -484,12 +700,6 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: up = jnp.logical_or(action == Action.UP, action == Action.UPFIRE) down = jnp.logical_or(action == Action.DOWN, action == Action.DOWNFIRE) jump = jnp.logical_or(action == Action.FIRE, jnp.logical_or(action == Action.UPFIRE, action == Action.DOWNFIRE)) - lives=jnp.array(self.consts.INITIAL_LIVES, dtype=jnp.int32), - is_dead=jnp.array(False), - respawn_timer=jnp.array(0, dtype=jnp.int32), - - - player_speed = state.player_car.speed player_speed = jax.lax.cond( @@ -505,9 +715,6 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: lambda s: s, operand=player_speed, ) - dividers = jnp.array([0, 1, 2, 4, 8]) - speed_divider = dividers[jnp.abs(player_speed)] - is_jumping = jnp.logical_or(jnp.logical_and(state.is_jumping, state.jump_cooldown > 0), jnp.logical_and(state.is_on_road, jnp.logical_and(player_speed >= 0, jnp.logical_and(state.jump_cooldown == 0, jump)))) jump_cooldown = jax.lax.cond( @@ -519,173 +726,32 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: operand=None), operand=state.jump_cooldown, ) - - - - - ##check if player is on the the road - is_on_road = ~state.is_jumping - - '''direction_change = jax.lax.cond( - jnp.logical_and(is_on_road, jnp.logical_or(jnp.logical_and(jnp.equal(road_index_A, state.player_car.road_index_A) , state.player_car.current_road == 0), (jnp.logical_and(jnp.equal(road_index_B, state.player_car.road_index_B) , state.player_car.current_road == 1)))) , - lambda s: False, - lambda s: True, - operand=None, - )''' - road_index_A = state.player_car.road_index_A - road_index_B = state.player_car.road_index_B - - car_direction_x = jax.lax.cond(state.player_car.current_road == 0, - lambda s: self.consts.FIRST_TRACK_CORNERS_X[road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_X[road_index_A], - lambda s: self.consts.SECOND_TRACK_CORNERS_X[road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[road_index_B], - operand=None), - car_direction_x = jax.lax.cond( - car_direction_x[0] > 0, - lambda s: 1, - lambda s: -1, - operand=car_direction_x, - ) - + is_on_road = ~is_jumping is_landing = jnp.logical_and(state.jump_cooldown == 1, jump_cooldown == 0) - ##calculate new position with speed (TODO: calculate better speed) - player_y = jax.lax.cond( - jnp.logical_and((state.step_counter % (16/ speed_divider) == 8 / speed_divider), player_speed != 0,), - lambda s: jax.lax.cond( - is_jumping, - lambda s: state.player_car.position.y + jax.lax.abs(player_speed) / player_speed * -1, - lambda s: jax.lax.cond( - self._isOnLine(state, jax.lax.abs(player_speed) / player_speed, 1), - lambda s: s + jax.lax.abs(player_speed) / player_speed * -1, - lambda s: jnp.array(s, float), - operand=state.player_car.position.y, - ), - operand=state.player_car.position.y), - lambda s: jnp.array(s, float), - operand=state.player_car.position.y, - ) - player_x = jax.lax.cond( - jnp.logical_and((state.step_counter % (16/ speed_divider) == 0), player_speed != 0,), - lambda s: jax.lax.cond( - is_jumping, - lambda s: s + jax.lax.abs(player_speed) / player_speed * car_direction_x, - lambda s: jax.lax.cond( - self._isOnLine(state, jax.lax.abs(player_speed) / player_speed, 2), - lambda s: s + jax.lax.abs(player_speed) / player_speed * car_direction_x, - lambda s: jnp.array(s, float), - operand=state.player_car.position.x, - ), - operand=state.player_car.position.x), - lambda s: jnp.array(s, float), - operand=state.player_car.position.x, - ) - - ##if y not on mx +b then no move - - - landing_in_Water, between_roads, road_A_x, road_B_x = self._landing_in_water(state, player_x, player_y) - landing_in_Water = jnp.logical_and(is_landing, landing_in_Water) - - - current_road = jax.lax.cond( - landing_in_Water, - lambda s: 2, - lambda s: jax.lax.cond( - is_on_road, - lambda s: state.player_car.current_road, - lambda s: jax.lax.cond( - jnp.abs(player_x - road_A_x) < jnp.abs(player_x - road_B_x), - lambda s: 0, - lambda s: 1, - operand=None, - ), - operand=None, - ), - operand=None, - ) - - road_index_A = jax.lax.cond( - current_road == 2, - lambda s: road_index_A, - lambda s: jax.lax.cond( - self.consts.TRACK_CORNERS_Y[road_index_A] < player_y, - lambda s: road_index_A - 1, - lambda s: jax.lax.cond( - len(self.consts.TRACK_CORNERS_Y) == road_index_A + 1, - lambda s: jax.lax.cond( - self.consts.TRACK_CORNERS_Y[0] > player_y, - lambda s: 0, - lambda s: road_index_A, - operand=None, - ), - lambda s: jax.lax.cond( - self.consts.TRACK_CORNERS_Y[road_index_A+1] > player_y, - lambda s: road_index_A + 1, - lambda s: road_index_A, - operand=None, - ), - operand=None, - ), - operand=None, - ), - operand=None, - ) - - road_index_B = jax.lax.cond( - current_road == 2, - lambda s: road_index_B, - lambda s: jax.lax.cond( - self.consts.TRACK_CORNERS_Y[road_index_B] < player_y, - lambda s: road_index_B - 1, - lambda s: jax.lax.cond( - len(self.consts.TRACK_CORNERS_Y) == road_index_B + 1, - lambda s: jax.lax.cond( - self.consts.TRACK_CORNERS_Y[0] > player_y, - lambda s: 0, - lambda s: road_index_B, - operand=None, - ), - lambda s: jax.lax.cond( - self.consts.TRACK_CORNERS_Y[road_index_B+1] > player_y, - lambda s: road_index_B + 1, - lambda s: road_index_B, - operand=None, - ), - operand=None, - ), - operand=None, - ), - operand=None, + updated_player_car = self._advance_car_core( + position_x=state.player_car.position.x, + position_y=state.player_car.position.y, + road_index_A=state.player_car.road_index_A, + road_index_B=state.player_car.road_index_B, + current_road=state.player_car.current_road, + speed=player_speed, + is_jumping=is_jumping, + is_on_road=is_on_road, + step_counter=state.step_counter, + width=state.player_car.position.width, + height=state.player_car.position.height, + car_type=state.player_car.type, + landing_check=is_landing, ) - jax.debug.print("Player X: {}, Player Y: {}, car_direction_x: {}", player_x, player_y, car_direction_x) - - # Calculate new player y position after wrapping - new_player_y = -((player_y * -1) % 1036) - return UpNDownState( score=state.score, difficulty=state.difficulty, jump_cooldown=jump_cooldown, is_jumping=is_jumping, is_on_road=is_on_road, - player_car=Car( - position=EntityPosition( - x=player_x, - y=new_player_y, - width=state.player_car.position.width, - height=state.player_car.position.height, - ), - speed=player_speed, - direction_x=car_direction_x, - current_road=current_road, - road_index_A=road_index_A, - road_index_B=road_index_B, - type=state.player_car.type, - ), - lives=state.lives, - is_dead=state.is_dead, - respawn_timer=state.respawn_timer, + player_car=updated_player_car, step_counter=state.step_counter + 1, round_started=jnp.logical_or(state.round_started, player_speed != 0), movement_steps=jax.lax.cond( @@ -698,6 +764,8 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: flags_collected_mask=state.flags_collected_mask, collectibles=state.collectibles, collectible_spawn_timer=state.collectible_spawn_timer, + enemy_cars=state.enemy_cars, + enemy_spawn_timer=state.enemy_spawn_timer, ) def _flag_step_main(self, state: UpNDownState) -> UpNDownState: @@ -727,19 +795,9 @@ def _flag_step_main(self, state: UpNDownState) -> UpNDownState: flags_collected_mask=new_flags_collected_mask, collectibles=state.collectibles, collectible_spawn_timer=state.collectible_spawn_timer, + enemy_cars=state.enemy_cars, + enemy_spawn_timer=state.enemy_spawn_timer, ) - def _completion_bonus_step(self, state: UpNDownState) -> UpNDownState: - all_flags_collected = jnp.all(state.flags_collected_mask) - - bonus = jax.lax.cond( - all_flags_collected, - lambda _: self.consts.ALL_FLAGS_BONUS, - lambda _: 0, - operand=None, - ) - return state._replace(score=state.score + bonus,lives=state.lives, - is_dead=state.is_dead, - respawn_timer=state.respawn_timer,) def _collectible_step_main(self, state: UpNDownState) -> UpNDownState: """Update collectible spawning, despawning, and collection.""" @@ -768,6 +826,159 @@ def _collectible_step_main(self, state: UpNDownState) -> UpNDownState: flags_collected_mask=state.flags_collected_mask, collectibles=updated_collectibles, collectible_spawn_timer=new_collectible_timer, + enemy_cars=state.enemy_cars, + enemy_spawn_timer=state.enemy_spawn_timer, + ) + + @partial(jax.jit, static_argnums=(0,)) + def _enemy_step_main(self, state: UpNDownState) -> UpNDownState: + """Spawn and move enemy cars that share the player's road logic.""" + base_key = jax.random.PRNGKey(2025) + step_key = jax.random.fold_in(base_key, state.step_counter) + key_spawn_offset, key_spawn_side, key_spawn_speed, key_spawn_direction, key_spawn_type, key_spawn_sign, key_flip_root = jax.random.split(step_key, 7) + + active_mask = state.enemy_cars.active + active_count = jnp.sum(active_mask.astype(jnp.int32)) + can_spawn = active_count < self.consts.MAX_ENEMY_CARS + + spawn_timer = jax.lax.cond( + jnp.logical_and(state.enemy_spawn_timer <= 0, can_spawn), + lambda _: self.consts.ENEMY_SPAWN_INTERVAL, + lambda _: state.enemy_spawn_timer - 1, + operand=None, + ) + should_spawn = jnp.logical_and(state.enemy_spawn_timer <= 0, can_spawn) + + inactive_mask = jnp.logical_not(active_mask) + first_inactive = jnp.argmax(inactive_mask.astype(jnp.int32)) + has_inactive = jnp.any(inactive_mask) + spawn_idx = jax.lax.cond(has_inactive, lambda _: first_inactive, lambda _: jnp.array(0, dtype=jnp.int32), operand=None) + spawn_mask = (jnp.arange(self.consts.MAX_ENEMY_CARS) == spawn_idx) & should_spawn & has_inactive + + spawn_offset = self.consts.ENEMY_OFFSCREEN_SPAWN_OFFSET + active_count * self.consts.ENEMY_MIN_SPAWN_GAP + jax.random.uniform(key_spawn_offset, minval=0.0, maxval=40.0) + spawn_side = jax.random.choice(key_spawn_side, jnp.array([-1.0, 1.0])) + raw_spawn_y = state.player_car.position.y + spawn_side * spawn_offset + spawn_y = -(((raw_spawn_y) * -1) % 1036) + spawn_road = jax.random.randint(key_spawn_direction, shape=(), minval=0, maxval=2) + + def get_road_segment(y): + segments = jnp.sum(self.consts.TRACK_CORNERS_Y > y) + return jnp.clip(segments - 1, 0, len(self.consts.TRACK_CORNERS_Y) - 2) + + segment_spawn = get_road_segment(spawn_y) + spawn_x = jax.lax.cond( + spawn_road == 0, + lambda _: self._get_x_on_road(spawn_y, segment_spawn, self.consts.FIRST_TRACK_CORNERS_X), + lambda _: self._get_x_on_road(spawn_y, segment_spawn, self.consts.SECOND_TRACK_CORNERS_X), + operand=None, + ) + + spawn_speed_mag = jax.random.randint(key_spawn_speed, shape=(), minval=self.consts.ENEMY_SPEED_MIN, maxval=self.consts.ENEMY_SPEED_MAX + 1) + spawn_speed_sign = jax.random.choice(key_spawn_sign, jnp.array([-1, 1])) + spawn_speed = spawn_speed_mag * spawn_speed_sign + spawn_type = jax.random.randint(key_spawn_type, shape=(), minval=0, maxval=4) + + direction_raw = jax.lax.cond( + spawn_road == 0, + lambda _: self.consts.FIRST_TRACK_CORNERS_X[segment_spawn+1] - self.consts.FIRST_TRACK_CORNERS_X[segment_spawn], + lambda _: self.consts.SECOND_TRACK_CORNERS_X[segment_spawn+1] - self.consts.SECOND_TRACK_CORNERS_X[segment_spawn], + operand=None, + ) + spawn_direction_x = jax.lax.cond(direction_raw > 0, lambda _: 1, lambda _: -1, operand=None) + + enemy_position_x = jnp.where(spawn_mask, spawn_x, state.enemy_cars.position.x) + enemy_position_y = jnp.where(spawn_mask, spawn_y, state.enemy_cars.position.y) + enemy_width = state.enemy_cars.position.width + enemy_height = state.enemy_cars.position.height + enemy_speed = jnp.where(spawn_mask, spawn_speed, state.enemy_cars.speed) + enemy_type = jnp.where(spawn_mask, spawn_type, state.enemy_cars.type) + enemy_current_road = jnp.where(spawn_mask, spawn_road, state.enemy_cars.current_road) + enemy_road_index_A = jnp.where(spawn_mask, segment_spawn, state.enemy_cars.road_index_A) + enemy_road_index_B = jnp.where(spawn_mask, segment_spawn, state.enemy_cars.road_index_B) + enemy_direction_x = jnp.where(spawn_mask, spawn_direction_x, state.enemy_cars.direction_x) + enemy_active = jnp.where(spawn_mask, True, state.enemy_cars.active) + enemy_age = jnp.where(spawn_mask, jnp.zeros_like(state.enemy_cars.age), state.enemy_cars.age) + + flip_keys = jax.random.split(key_flip_root, self.consts.MAX_ENEMY_CARS) + flip_mask = jax.vmap(lambda k: jax.random.uniform(k) < self.consts.ENEMY_DIRECTION_SWITCH_PROB)(flip_keys) + enemy_speed = jnp.where(jnp.logical_and(enemy_active, flip_mask), -enemy_speed, enemy_speed) + + move_fn = lambda px, py, ra, rb, cr, sp, tp: self._advance_car_core( + position_x=px, + position_y=py, + road_index_A=ra, + road_index_B=rb, + current_road=cr, + speed=sp, + is_jumping=False, + is_on_road=True, + step_counter=state.step_counter, + width=self.consts.PLAYER_SIZE[0], + height=self.consts.PLAYER_SIZE[1], + car_type=tp, + landing_check=False, + ) + + advanced_cars = jax.vmap(move_fn)( + enemy_position_x, + enemy_position_y, + enemy_road_index_A, + enemy_road_index_B, + enemy_current_road, + enemy_speed, + enemy_type, + ) + + moved_position_x = jnp.where(enemy_active, advanced_cars.position.x, enemy_position_x) + moved_position_y = jnp.where(enemy_active, advanced_cars.position.y, enemy_position_y) + moved_road_index_A = jnp.where(enemy_active, advanced_cars.road_index_A, enemy_road_index_A) + moved_road_index_B = jnp.where(enemy_active, advanced_cars.road_index_B, enemy_road_index_B) + moved_current_road = jnp.where(enemy_active, advanced_cars.current_road, enemy_current_road) + moved_direction_x = jnp.where(enemy_active, advanced_cars.direction_x, enemy_direction_x) + + enemy_age = jnp.where(enemy_active, enemy_age + 1, enemy_age) + + delta_y = moved_position_y - state.player_car.position.y + wrapped_dist = jnp.minimum(jnp.abs(delta_y), 1036 - jnp.abs(delta_y)) + far_mask = wrapped_dist > self.consts.ENEMY_DESPAWN_DISTANCE + age_mask = enemy_age > self.consts.ENEMY_MAX_AGE + despawn_mask = jnp.logical_and(enemy_active, jnp.logical_or(far_mask, age_mask)) + final_active = jnp.logical_and(enemy_active, jnp.logical_not(despawn_mask)) + enemy_age = jnp.where(despawn_mask, jnp.zeros_like(enemy_age), enemy_age) + + next_enemy_cars = EnemyCars( + position=EntityPosition( + x=moved_position_x, + y=moved_position_y, + width=enemy_width, + height=enemy_height, + ), + speed=enemy_speed, + type=enemy_type, + current_road=moved_current_road, + road_index_A=moved_road_index_A, + road_index_B=moved_road_index_B, + direction_x=moved_direction_x, + active=final_active, + age=enemy_age, + ) + + return UpNDownState( + score=state.score, + difficulty=state.difficulty, + jump_cooldown=state.jump_cooldown, + is_jumping=state.is_jumping, + is_on_road=state.is_on_road, + player_car=state.player_car, + step_counter=state.step_counter, + round_started=state.round_started, + movement_steps=state.movement_steps, + flags=state.flags, + flags_collected_mask=state.flags_collected_mask, + collectibles=state.collectibles, + collectible_spawn_timer=state.collectible_spawn_timer, + enemy_cars=next_enemy_cars, + enemy_spawn_timer=spawn_timer, ) def _passive_score_step_main(self, state: UpNDownState) -> UpNDownState: @@ -796,6 +1007,8 @@ def _passive_score_step_main(self, state: UpNDownState) -> UpNDownState: flags_collected_mask=state.flags_collected_mask, collectibles=state.collectibles, collectible_spawn_timer=state.collectible_spawn_timer, + enemy_cars=state.enemy_cars, + enemy_spawn_timer=state.enemy_spawn_timer, ) @@ -845,45 +1058,90 @@ def get_road_segment(y): type_id=jnp.zeros(self.consts.MAX_COLLECTIBLES, dtype=jnp.int32), active=jnp.zeros(self.consts.MAX_COLLECTIBLES, dtype=jnp.bool_), ) - player_car = Car( - position=EntityPosition( - x=jnp.array(30, dtype=jnp.int32), - y=jnp.array(0, dtype=jnp.int32), - width=jnp.array(self.consts.PLAYER_SIZE[0], dtype=jnp.int32), - height=jnp.array(self.consts.PLAYER_SIZE[1], dtype=jnp.int32), - ), - speed=jnp.array(0, dtype=jnp.int32), - direction_x=jnp.array(0, dtype=jnp.int32), - current_road=jnp.array(0, dtype=jnp.int32), - road_index_A=jnp.array(0, dtype=jnp.int32), - road_index_B=jnp.array(0, dtype=jnp.int32), - type=jnp.array(0, dtype=jnp.int32), - ) - state = UpNDownState( - score=jnp.array(0, dtype=jnp.int32), - difficulty=jnp.array(self.consts.DIFFICULTIES[0], dtype=jnp.int32), - jump_cooldown=jnp.array(0, dtype=jnp.int32), - is_jumping=jnp.array(False), - is_on_road=jnp.array(True), - - player_car=player_car, - step_counter=jnp.array(0, dtype=jnp.int32), - round_started=jnp.array(False), - movement_steps=jnp.array(0, dtype=jnp.int32), + def get_road_segment(y): + segments = jnp.sum(self.consts.TRACK_CORNERS_Y > y) + return jnp.clip(segments - 1, 0, len(self.consts.TRACK_CORNERS_Y) - 2) - flags=flags, - flags_collected_mask=jnp.zeros(self.consts.NUM_FLAGS, dtype=jnp.bool_), - collectibles=collectibles, - collectible_spawn_timer=jnp.array(0, dtype=jnp.int32), + # Seed initial visible enemies spaced around the player + key_init, key_type, key_road, key_speed, key_sign = jax.random.split(key, 5) + player_start_y = 0.0 + offsets = self.consts.INITIAL_ENEMY_BASE_OFFSET + self.consts.INITIAL_ENEMY_GAP * jnp.arange(self.consts.INITIAL_ENEMY_COUNT) + spawn_signs = jax.random.choice(key_sign, jnp.array([-1.0, 1.0]), shape=(self.consts.INITIAL_ENEMY_COUNT,)) + raw_spawn_y = player_start_y + spawn_signs * offsets + init_y = -(((raw_spawn_y) * -1) % 1036) + init_road = jax.random.randint(key_road, shape=(self.consts.INITIAL_ENEMY_COUNT,), minval=0, maxval=2) + init_segments = jax.vmap(get_road_segment)(init_y) + init_x = jax.vmap(lambda y, seg, road: jax.lax.cond( + road == 0, + lambda _: self._get_x_on_road(y, seg, self.consts.FIRST_TRACK_CORNERS_X), + lambda _: self._get_x_on_road(y, seg, self.consts.SECOND_TRACK_CORNERS_X), + operand=None, + ))(init_y, init_segments, init_road) + init_type = jax.random.randint(key_type, shape=(self.consts.INITIAL_ENEMY_COUNT,), minval=0, maxval=4) + init_speed_mag = jax.random.randint(key_speed, shape=(self.consts.INITIAL_ENEMY_COUNT,), minval=self.consts.ENEMY_SPEED_MIN, maxval=self.consts.ENEMY_SPEED_MAX + 1) + init_speed_sign = jax.random.choice(key_init, jnp.array([-1, 1]), shape=(self.consts.INITIAL_ENEMY_COUNT,)) + init_speed = init_speed_mag * init_speed_sign + + def init_direction(seg, road): + raw = jax.lax.cond( + road == 0, + lambda _: self.consts.FIRST_TRACK_CORNERS_X[seg+1] - self.consts.FIRST_TRACK_CORNERS_X[seg], + lambda _: self.consts.SECOND_TRACK_CORNERS_X[seg+1] - self.consts.SECOND_TRACK_CORNERS_X[seg], + operand=None, + ) + return jax.lax.cond(raw > 0, lambda _: 1, lambda _: -1, operand=None) - # -------- NEW REQUIRED FIELDS -------- - lives=jnp.array(self.consts.INITIAL_LIVES, dtype=jnp.int32), - is_dead=jnp.array(False), - respawn_timer=jnp.array(0, dtype=jnp.int32), - ) + init_dir = jax.vmap(init_direction)(init_segments, init_road) + pad = self.consts.MAX_ENEMY_CARS - self.consts.INITIAL_ENEMY_COUNT + enemy_cars = EnemyCars( + position=EntityPosition( + x=jnp.concatenate([init_x, jnp.zeros(pad, dtype=jnp.float32)]), + y=jnp.concatenate([init_y, jnp.zeros(pad, dtype=jnp.float32)]), + width=jnp.full((self.consts.MAX_ENEMY_CARS,), self.consts.PLAYER_SIZE[0]), + height=jnp.full((self.consts.MAX_ENEMY_CARS,), self.consts.PLAYER_SIZE[1]), + ), + speed=jnp.concatenate([init_speed, jnp.zeros(pad, dtype=jnp.int32)]), + type=jnp.concatenate([init_type, jnp.zeros(pad, dtype=jnp.int32)]), + current_road=jnp.concatenate([init_road, jnp.zeros(pad, dtype=jnp.int32)]), + road_index_A=jnp.concatenate([init_segments, jnp.zeros(pad, dtype=jnp.int32)]), + road_index_B=jnp.concatenate([init_segments, jnp.zeros(pad, dtype=jnp.int32)]), + direction_x=jnp.concatenate([init_dir, jnp.zeros(pad, dtype=jnp.int32)]), + active=jnp.concatenate([jnp.ones(self.consts.INITIAL_ENEMY_COUNT, dtype=jnp.bool_), jnp.zeros(pad, dtype=jnp.bool_)]), + age=jnp.concatenate([jnp.zeros(self.consts.INITIAL_ENEMY_COUNT, dtype=jnp.int32), jnp.zeros(pad, dtype=jnp.int32)]), + ) + state = UpNDownState( + score=0, + difficulty=self.consts.DIFFICULTIES[0], + jump_cooldown=0, + is_jumping=False, + is_on_road=True, + player_car=Car( + position=EntityPosition( + x=30, + y= 0, + width=self.consts.PLAYER_SIZE[0], + height=self.consts.PLAYER_SIZE[1], + ), + speed=0, + direction_x=0, + current_road=0, + road_index_A=0, + road_index_B=0, + type=0, + ), + step_counter=jnp.array(0), + round_started=jnp.array(False), + movement_steps=jnp.array(0), + flags=flags, + flags_collected_mask=jnp.zeros(self.consts.NUM_FLAGS, dtype=jnp.bool_), + collectibles=collectibles, + collectible_spawn_timer=jnp.array(0, dtype=jnp.int32), + enemy_cars=enemy_cars, + enemy_spawn_timer=jnp.array(self.consts.ENEMY_SPAWN_INTERVAL, dtype=jnp.int32), + ) initial_obs = self._get_observation(state) return initial_obs, state @@ -897,8 +1155,7 @@ def step(self, state: UpNDownState, action: chex.Array) -> Tuple[UpNDownObservat state = self._flag_step_main(state) state = self._completion_bonus_step(state) state = self._collectible_step_main(state) - - + state = self._enemy_step_main(state) done = self._get_done(state) env_reward = self._get_reward(previous_state, state) @@ -1006,6 +1263,37 @@ def __init__(self, consts: UpNDownConstants = None): repeats = max(1, math.ceil(self.view_height / road_cycle) + 2) self._road_tile_offsets = jnp.arange(-repeats, repeats + 1, dtype=jnp.int32) * jnp.int32(self.complete_road_size) self._num_road_tiles = int(self._road_tile_offsets.shape[0]) + + self.enemy_sprite_names = { + self.consts.ENEMY_TYPE_CAMERO: ("camero_left", "camero_right"), + self.consts.ENEMY_TYPE_FLAG_CARRIER: ("flag_carrier_left", "flag_carrier_right"), + self.consts.ENEMY_TYPE_PICKUP: ("pick_up_truck_left", "pick_up_truck_right"), + self.consts.ENEMY_TYPE_TRUCK: ("truck_left", "truck_right"), + } + + # Pre-pad enemy masks to a common shape so switch/array indexing works under jit + enemy_left_raw = [ + self.SHAPE_MASKS["camero_left"], + self.SHAPE_MASKS["flag_carrier_left"], + self.SHAPE_MASKS["pick_up_truck_left"], + self.SHAPE_MASKS["truck_left"], + ] + enemy_right_raw = [ + self.SHAPE_MASKS["camero_right"], + self.SHAPE_MASKS["flag_carrier_right"], + self.SHAPE_MASKS["pick_up_truck_right"], + self.SHAPE_MASKS["truck_right"], + ] + max_h = max([m.shape[0] for m in enemy_left_raw + enemy_right_raw]) + max_w = max([m.shape[1] for m in enemy_left_raw + enemy_right_raw]) + + def _pad_mask(mask): + pad_h = max_h - mask.shape[0] + pad_w = max_w - mask.shape[1] + return jnp.pad(mask, ((0, pad_h), (0, pad_w)), constant_values=self.jr.TRANSPARENT_ID) + + self.enemy_left_masks = jnp.stack([_pad_mask(m) for m in enemy_left_raw], axis=0) + self.enemy_right_masks = jnp.stack([_pad_mask(m) for m in enemy_right_raw], axis=0) # Precompute flag mask data for recoloring without special-casing pink self.flag_base_mask = self.SHAPE_MASKS["pink_flag"] @@ -1081,6 +1369,14 @@ def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.n {'name': 'background', 'type': 'background', 'data': backgroundSprite}, {'name': 'road', 'type': 'group', 'files': roads}, {'name': 'player', 'type': 'single', 'file': 'player_car.npy'}, + {'name': 'camero_left', 'type': 'single', 'file': 'enemy_cars/camero_left.npy'}, + {'name': 'camero_right', 'type': 'single', 'file': 'enemy_cars/camero_right.npy'}, + {'name': 'flag_carrier_left', 'type': 'single', 'file': 'enemy_cars/flag_carrier_left.npy'}, + {'name': 'flag_carrier_right', 'type': 'single', 'file': 'enemy_cars/flag_carrier_right.npy'}, + {'name': 'pick_up_truck_left', 'type': 'single', 'file': 'enemy_cars/pick_up_truck_left.npy'}, + {'name': 'pick_up_truck_right', 'type': 'single', 'file': 'enemy_cars/pick_up_truck_right.npy'}, + {'name': 'truck_left', 'type': 'single', 'file': 'enemy_cars/truck_left.npy'}, + {'name': 'truck_right', 'type': 'single', 'file': 'enemy_cars/truck_right.npy'}, {'name': 'wall_top', 'type': 'procedural', 'data': topBlockSprite}, {'name': 'wall_bottom', 'type': 'procedural', 'data': bottomBlockSprite}, {'name': 'all_flags_top', 'type': 'single', 'file': 'all_flags_top.npy'}, @@ -1146,6 +1442,32 @@ def combine(i, acc): raster = jax.lax.fori_loop(0, total_segments, combine, raster) + def select_enemy_mask(enemy_type: chex.Array, going_left: chex.Array): + left_mask = self.enemy_left_masks[enemy_type] + right_mask = self.enemy_right_masks[enemy_type] + return jax.lax.cond(going_left, lambda _: left_mask, lambda _: right_mask, operand=None) + + def render_enemy(carry, enemy_idx): + raster = carry + enemy_active = state.enemy_cars.active[enemy_idx] + enemy_x = state.enemy_cars.position.x[enemy_idx] + enemy_y = state.enemy_cars.position.y[enemy_idx] + enemy_type = state.enemy_cars.type[enemy_idx] + direction_x = state.enemy_cars.direction_x[enemy_idx] + screen_y = 105 + (enemy_y - state.player_car.position.y) + is_visible = jnp.logical_and(enemy_active, jnp.logical_and(screen_y > 25, screen_y < 195)) + enemy_mask = select_enemy_mask(enemy_type, direction_x < 0) + + raster = jax.lax.cond( + is_visible, + lambda r: self.jr.render_at(r, enemy_x.astype(jnp.int32), screen_y.astype(jnp.int32), enemy_mask), + lambda r: r, + operand=raster, + ) + return raster, None + + raster, _ = jax.lax.scan(render_enemy, raster, jnp.arange(self.consts.MAX_ENEMY_CARS)) + player_mask = self.SHAPE_MASKS["player"] raster = self.jr.render_at(raster, state.player_car.position.x, 105, player_mask) From c22cd6837d76d16a83614e3a8bf3bf1378987e1f Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sat, 20 Dec 2025 21:16:54 +0100 Subject: [PATCH 56/76] add jumping, crossing roads to game --- src/jaxatari/games/jax_upndown.py | 1343 +++++++++++++++++++---------- 1 file changed, 880 insertions(+), 463 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 648859a53..2d8de2915 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -3,6 +3,7 @@ import math from functools import partial from typing import NamedTuple, Tuple +import jax import jax.lax import jax.numpy as jnp import chex @@ -16,31 +17,37 @@ class UpNDownConstants(NamedTuple): FRAME_SKIP: int = 4 DIFFICULTIES: chex.Array = jnp.array([0, 1, 2, 3, 4, 5]) ACTION_REPEAT_PROBS: float = 0.25 - MAX_SPEED: int = 5 + MAX_SPEED: int = 6 + INITIAL_LIVES: int = 5 + RESPAWN_HIDE_FRAMES: int = 8 + JUMP_ARC_HEIGHT: float = 18.0 # Enemy spawning and movement - MAX_ENEMY_CARS: int = 6 + MAX_ENEMY_CARS: int = 8 ENEMY_SPAWN_INTERVAL: int = 80 ENEMY_DESPAWN_DISTANCE: int = 300 - ENEMY_SPEED_MIN: int = 2 + ENEMY_SPEED_MIN: int = 3 ENEMY_SPEED_MAX: int = 5 - ENEMY_DIRECTION_SWITCH_PROB: float = 0.005 + ENEMY_DIRECTION_SWITCH_PROB: float = 0.0001 ENEMY_OFFSCREEN_SPAWN_OFFSET: float = 100.0 - ENEMY_MIN_SPAWN_GAP: float = 40.0 - ENEMY_MAX_AGE: int = 900 - INITIAL_ENEMY_COUNT: int = 3 + ENEMY_MIN_SPAWN_GAP: float = 30.0 + ENEMY_MAX_AGE: int = 1900 + INITIAL_ENEMY_COUNT: int = 4 INITIAL_ENEMY_BASE_OFFSET: float = 40.0 - INITIAL_ENEMY_GAP: float = 50.0 + INITIAL_ENEMY_GAP: float = 30.0 ENEMY_TYPE_CAMERO: int = 0 ENEMY_TYPE_FLAG_CARRIER: int = 1 ENEMY_TYPE_PICKUP: int = 2 ENEMY_TYPE_TRUCK: int = 3 - JUMP_FRAMES: int = 10 - ALL_FLAGS_BONUS: int = 1000 - LANDING_ZONE: int = 15 - FIRST_ROAD_LENGTH: int = 4 - SECOND_ROAD_LENGTH: int = 4 + JUMP_FRAMES: int = 28 + POST_JUMP_DELAY: int = 10 + LANDING_TOLERANCE: int = 15 # Pixels tolerance for landing on a road (increased by 5 for off-road landings) + LATE_JUMP_COLLISION_FRAMES: int = 2 + LANDING_COLLISION_DISTANCE: float = 8.0 # Larger collision distance when landing (for crossings) + LATE_JUMP_ENEMY_SCORE: int = 400 + STEEP_ROAD_SPEED_REDUCTION_INTERVAL: int = 8 # Frames between each speed reduction on steep roads + TRACK_LENGTH: int = 1036 FIRST_TRACK_CORNERS_X: chex.Array = jnp.array([30, 75, 128, 75, 21, 75, 131, 111, 150, 95, 150, 115, 150, 108, 150, 115, 115, 75, 18, 38, 67, 38, 38, 20, 64, 30]) - TRACK_CORNERS_Y: chex.Array = jnp.array([0, -40, -98, -155, -203, -268, -327, -347, -382, -467, -525, -565, -597, -625, -670, -705, -738, -788, -838, -862, -898, -925, -950, -972, -1000, -1036]) + TRACK_CORNERS_Y: chex.Array = jnp.array([0, -40, -98, -155, -203, -268, -327, -347, -382, -467, -525, -565, -597, -625, -670, -705, -738, -788, -838, -862, -898, -925, -950, -972, -1000, -1035]) SECOND_TRACK_CORNERS_X: chex.Array = jnp.array([115, 75, 20, 75, 133, 75, 22, 37, 63, 27, 66, 30, 63, 24, 60, 38, 38, 75, 131, 111, 150, 118, 118, 98, 150, 115]) PLAYER_SIZE: Tuple[int, int] = (4, 16) INITIAL_ROAD_POS_Y: int = 25 @@ -60,7 +67,7 @@ class UpNDownConstants(NamedTuple): [78, 50, 181, 255], # Blue ]) # Top display positions for each flag (x coordinates where blackout squares appear) - FLAG_TOP_X_POSITIONS: chex.Array = jnp.array([13, 30, 47, 64, 82, 98, 118, 132]) + FLAG_TOP_X_POSITIONS: chex.Array = jnp.array([13, 30, 47, 64, 82, 98, 118, 134]) FLAG_TOP_Y: int = 20 FLAG_BLACKOUT_SIZE: Tuple[int, int] = (14, 14) # Size of blackout square FLAG_COLLECTION_SCORE: int = 75 # Points awarded for collecting a flag @@ -79,7 +86,7 @@ class UpNDownConstants(NamedTuple): COLLECTIBLE_TYPE_LOLLYPOP: int = 2 COLLECTIBLE_TYPE_ICE_CREAM: int = 3 # Collectible type spawn probabilities (must sum to 100) - COLLECTIBLE_SPAWN_PROBABILITIES: chex.Array = jnp.array([40, 20, 20, 20], dtype=jnp.int32) # Cherry: 40%, Balloon: 20%, Lollypop: 20%, IceCream: 20% + COLLECTIBLE_SPAWN_PROBABILITIES: chex.Array = jnp.array([35, 30, 25, 10], dtype=jnp.int32) # Cherry: 40%, Balloon: 20%, Lollypop: 20%, IceCream: 20% # Collectible type scores COLLECTIBLE_SCORES: chex.Array = jnp.array([50, 65, 70, 75], dtype=jnp.int32) # [cherry, balloon, lollypop, ice_cream] # Shared collectible colors @@ -139,8 +146,11 @@ class EnemyCars(NamedTuple): class UpNDownState(NamedTuple): score: chex.Array + lives: chex.Array + respawn_cooldown: chex.Array difficulty: chex.Array jump_cooldown: chex.Array + post_jump_cooldown: chex.Array is_jumping: chex.Array is_on_road: chex.Array player_car: Car @@ -150,6 +160,8 @@ class UpNDownState(NamedTuple): step_counter: chex.Array round_started: chex.Array movement_steps: chex.Array + steep_road_timer: chex.Array # Timer for steep road speed reduction + jump_slope: chex.Array # X movement per Y step, locked at jump start (float) # Flag state - tracks all 8 flags flags: Flag # Contains arrays of size NUM_FLAGS for each field flags_collected_mask: chex.Array # Boolean mask of which flag colors have been collected (size NUM_FLAGS) @@ -239,6 +251,32 @@ def _get_x_on_road(self, y: chex.Array, road_segment: chex.Array, track_corners_ operand=None, ) return x1 + t * (x2 - x1) + + @partial(jax.jit, static_argnums=(0,)) + def _get_road_segment(self, y: chex.Array) -> chex.Array: + """Return the road segment index for a given y position.""" + segments = jnp.sum(self.consts.TRACK_CORNERS_Y > y, dtype=jnp.int32) + max_idx = jnp.int32(len(self.consts.TRACK_CORNERS_Y) - 1) + return jnp.clip(segments - 1, 0, max_idx) + + @partial(jax.jit, static_argnums=(0,)) + def _is_steep_road_segment(self, current_road: chex.Array, road_index_A: chex.Array, road_index_B: chex.Array) -> chex.Array: + """Check if the current road segment is steep (no X direction change). + + A steep segment is one where the X coordinates of consecutive corners are the same, + meaning the road goes straight up/down with no horizontal movement. + + Returns True if the segment is steep (requires jump to pass when going up). + """ + # Get the X difference for the current road segment + x_diff = jax.lax.cond( + current_road == 0, + lambda _: jnp.abs(self.consts.FIRST_TRACK_CORNERS_X[road_index_A + 1] - self.consts.FIRST_TRACK_CORNERS_X[road_index_A]), + lambda _: jnp.abs(self.consts.SECOND_TRACK_CORNERS_X[road_index_B + 1] - self.consts.SECOND_TRACK_CORNERS_X[road_index_B]), + operand=None, + ) + # A segment is steep if there's no X change (or very small change) + return x_diff < 1.0 @partial(jax.jit, static_argnums=(0,)) def _isOnLine(self, state: UpNDownState, player_speed: chex.Array, turn: chex.Array) -> chex.Array: @@ -251,7 +289,7 @@ def _landing_in_water(self, state: UpNDownState, new_position_x: chex.Array, new road_B_x = ((new_position_y - self.consts.TRACK_CORNERS_Y[state.player_car.road_index_B]) / (self.consts.TRACK_CORNERS_Y[state.player_car.road_index_B+1] - self.consts.TRACK_CORNERS_Y[state.player_car.road_index_B])) * (self.consts.SECOND_TRACK_CORNERS_X[state.player_car.road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[state.player_car.road_index_B]) + self.consts.SECOND_TRACK_CORNERS_X[state.player_car.road_index_B] distance_to_road_A = jnp.abs(new_position_x - road_A_x) distance_to_road_B = jnp.abs(new_position_x - road_B_x) - landing_in_Water = jnp.logical_and(distance_to_road_A > self.consts.LANDING_ZONE, distance_to_road_B > self.consts.LANDING_ZONE) + landing_in_Water = jnp.logical_and(distance_to_road_A > self.consts.LANDING_TOLERANCE, distance_to_road_B > self.consts.LANDING_TOLERANCE) between_roads = jnp.logical_and(new_position_x > jnp.minimum(road_A_x, road_B_x), new_position_x < jnp.maximum(road_A_x, road_B_x)) return landing_in_Water, between_roads, road_A_x, road_B_x @@ -261,12 +299,12 @@ def _landing_in_water_for_indices(self, road_index_A: chex.Array, road_index_B: road_B_x = ((new_position_y - self.consts.TRACK_CORNERS_Y[road_index_B]) / (self.consts.TRACK_CORNERS_Y[road_index_B+1] - self.consts.TRACK_CORNERS_Y[road_index_B])) * (self.consts.SECOND_TRACK_CORNERS_X[road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[road_index_B]) + self.consts.SECOND_TRACK_CORNERS_X[road_index_B] distance_to_road_A = jnp.abs(new_position_x - road_A_x) distance_to_road_B = jnp.abs(new_position_x - road_B_x) - landing_in_water = jnp.logical_and(distance_to_road_A > self.consts.LANDING_ZONE, distance_to_road_B > self.consts.LANDING_ZONE) + landing_in_water = jnp.logical_and(distance_to_road_A > self.consts.LANDING_TOLERANCE, distance_to_road_B > self.consts.LANDING_TOLERANCE) between_roads = jnp.logical_and(new_position_x > jnp.minimum(road_A_x, road_B_x), new_position_x < jnp.maximum(road_A_x, road_B_x)) return landing_in_water, between_roads, road_A_x, road_B_x @partial(jax.jit, static_argnums=(0,)) - def _advance_car_core( + def _advance_player_car( self, position_x: chex.Array, position_y: chex.Array, @@ -280,18 +318,33 @@ def _advance_car_core( width: chex.Array, height: chex.Array, car_type: chex.Array, - landing_check: chex.Array, + is_landing: chex.Array, + stored_jump_slope: chex.Array, ) -> Car: - dividers = jnp.array([0, 1, 2, 4, 8, 16]) + """ + Advance the player car position. + + Jump logic: + - Car jumps in the direction of the road it's on at current speed + - While jumping, car moves freely (not constrained to road) + - On landing: check if car is on/near a road or between roads + - If between roads: snap to nearest road + - If too far from both roads (outside the road area): crash (water) + """ + # Speed-based movement timing + dividers = jnp.array([0, 1, 2, 4, 8, 16, 16]) abs_speed = jnp.abs(speed) - speed_divider = dividers[abs_speed] + speed_index = jnp.minimum(abs_speed, jnp.int32(dividers.shape[0] - 1)) + speed_divider = dividers[speed_index] effective_divider = jnp.maximum(1, speed_divider) period = jnp.maximum(1, 16 // effective_divider) half_period = jnp.maximum(1, period // 2) speed_sign = jax.lax.cond(speed != 0, lambda _: jax.lax.abs(speed) / speed, lambda _: jnp.array(0.0), operand=None) + # Get slope and intercept for current road slope, b = self._get_slope_and_intercept_from_indices(current_road, road_index_A, road_index_B) + # Determine X direction based on current road segment (for normal movement) direction_raw = jax.lax.cond( current_road == 0, lambda _: self.consts.FIRST_TRACK_CORNERS_X[road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_X[road_index_A], @@ -300,19 +353,26 @@ def _advance_car_core( ) car_direction_x = jax.lax.cond(direction_raw > 0, lambda _: 1, lambda _: -1, operand=None) + # Movement timing flags move_y = jnp.logical_and((step_counter % period) == (half_period % period), speed != 0) move_x = jnp.logical_and((step_counter % period) == 0, speed != 0) + # Step size (slightly larger at max speed) + step_size = jnp.where(speed_index >= self.consts.MAX_SPEED, 1.5, 1.0) + position = EntityPosition(x=position_x, y=position_y, width=width, height=height) + # === Y MOVEMENT === + # When jumping: move freely in Y direction + # When on road: only move if allowed by road geometry new_player_y = jax.lax.cond( move_y, lambda _: jax.lax.cond( is_jumping, - lambda _: position_y + speed_sign * -1, + lambda _: position_y + speed_sign * -step_size, # Free movement while jumping lambda _: jax.lax.cond( self._is_on_line_for_position(position, slope, b, speed_sign, 1), - lambda _: position_y + speed_sign * -1, + lambda _: position_y + speed_sign * -step_size, lambda _: jnp.array(position_y, float), operand=None, ), @@ -322,14 +382,18 @@ def _advance_car_core( operand=None, ) + # === X MOVEMENT === + # When jumping: use stored_jump_slope (locked at jump start) - moves X proportionally to Y + # The slope already encodes direction (dx/dy), so multiply by Y step size and speed_sign + # When on road: only move if allowed by road geometry new_player_x = jax.lax.cond( move_x, lambda _: jax.lax.cond( is_jumping, - lambda _: position_x + speed_sign * car_direction_x, + lambda _: position_x - speed_sign * stored_jump_slope * step_size, # Slope-based movement (negated because Y decreases going forward) lambda _: jax.lax.cond( self._is_on_line_for_position(position, slope, b, speed_sign, 2), - lambda _: position_x + speed_sign * car_direction_x, + lambda _: position_x + speed_sign * car_direction_x * step_size, # Normal road movement lambda _: jnp.array(position_x, float), operand=None, ), @@ -339,44 +403,76 @@ def _advance_car_core( operand=None, ) - landing_in_water, between_roads, road_A_x, road_B_x = self._landing_in_water_for_indices(road_index_A, road_index_B, new_player_x, new_player_y) - landing_in_water = jnp.logical_and(landing_check, landing_in_water) - - updated_current_road = jax.lax.cond( - landing_in_water, - lambda _: 2, + # === LANDING LOGIC === + # Get the current road segment based on new Y position + segment = self._get_road_segment(new_player_y) + + # Calculate X positions of both roads at the new Y position + road_A_x = self._get_x_on_road(new_player_y, segment, self.consts.FIRST_TRACK_CORNERS_X) + road_B_x = self._get_x_on_road(new_player_y, segment, self.consts.SECOND_TRACK_CORNERS_X) + + # Calculate distances to each road + dist_to_road_A = jnp.abs(new_player_x - road_A_x) + dist_to_road_B = jnp.abs(new_player_x - road_B_x) + + # Check if player is close enough to either road (within tolerance) + on_road_A = dist_to_road_A <= self.consts.LANDING_TOLERANCE + on_road_B = dist_to_road_B <= self.consts.LANDING_TOLERANCE + on_any_road = jnp.logical_or(on_road_A, on_road_B) + + # Check if player is between the two roads + min_road_x = jnp.minimum(road_A_x, road_B_x) + max_road_x = jnp.maximum(road_A_x, road_B_x) + between_roads = jnp.logical_and(new_player_x > min_road_x, new_player_x < max_road_x) + + # Determine which road is closer + closer_to_A = dist_to_road_A < dist_to_road_B + nearest_road_x = jnp.where(closer_to_A, road_A_x, road_B_x) + nearest_road_id = jnp.where(closer_to_A, jnp.int32(0), jnp.int32(1)) + + # === LANDING OUTCOMES === + # Valid landing: on a road OR between roads (will snap to nearest) + valid_landing = jnp.logical_or(on_any_road, between_roads) + + # If landing and between roads but not directly on a road, snap to nearest road + should_snap = jnp.logical_and(is_landing, jnp.logical_and(between_roads, jnp.logical_not(on_any_road))) + final_player_x = jnp.where(should_snap, nearest_road_x, new_player_x) + + # Water landing (crash): landing outside the valid road area + landing_in_water = jnp.logical_and(is_landing, jnp.logical_not(valid_landing)) + + # === UPDATE ROAD STATE === + # Determine which road to assign on landing + landed_road = jax.lax.cond( + on_road_A, + lambda _: jnp.int32(0), lambda _: jax.lax.cond( - is_on_road, - lambda _: current_road, - lambda _: jax.lax.cond( - jnp.abs(new_player_x - road_A_x) < jnp.abs(new_player_x - road_B_x), - lambda _: 0, - lambda _: 1, - operand=None, - ), + on_road_B, + lambda _: jnp.int32(1), + lambda _: nearest_road_id, # Between roads - use nearest operand=None, ), operand=None, ) - - next_road_index_A = jax.lax.cond( - updated_current_road == 2, - lambda _: road_index_A, + + # Update current_road + # - If landing in water: set to 2 (water/crash marker) + # - If landing successfully: set to the landed road + # - If still jumping: keep current road (frozen during jump) + # - If on road normally: update based on position + updated_current_road = jax.lax.cond( + landing_in_water, + lambda _: jnp.int32(2), # Water crash lambda _: jax.lax.cond( - self.consts.TRACK_CORNERS_Y[road_index_A] < new_player_y, - lambda _: road_index_A - 1, + is_landing, + lambda _: landed_road, # Successfully landed lambda _: jax.lax.cond( - len(self.consts.TRACK_CORNERS_Y) == road_index_A + 1, + is_jumping, + lambda _: current_road, # Keep road frozen while jumping lambda _: jax.lax.cond( - self.consts.TRACK_CORNERS_Y[0] > new_player_y, - lambda _: 0, - lambda _: road_index_A, - operand=None, - ), - lambda _: jax.lax.cond( - self.consts.TRACK_CORNERS_Y[road_index_A+1] > new_player_y, - lambda _: road_index_A + 1, - lambda _: road_index_A, + current_road == 2, + lambda _: nearest_road_id, # Recover from water state + lambda _: current_road, # Normal on-road movement operand=None, ), operand=None, @@ -385,46 +481,135 @@ def _advance_car_core( ), operand=None, ) - + + # Update road indices to match current segment when not jumping + next_road_index_A = jax.lax.cond( + jnp.logical_and(jnp.logical_not(is_jumping), updated_current_road == 0), + lambda _: segment, + lambda _: road_index_A, + operand=None, + ) + next_road_index_B = jax.lax.cond( - updated_current_road == 2, + jnp.logical_and(jnp.logical_not(is_jumping), updated_current_road == 1), + lambda _: segment, lambda _: road_index_B, + operand=None, + ) + + # Wrap Y position for looping track + wrapped_y = -((new_player_y * -1) % 1036) + + return Car( + position=EntityPosition( + x=final_player_x, + y=wrapped_y, + width=width, + height=height, + ), + speed=speed, + direction_x=car_direction_x, + current_road=updated_current_road, + road_index_A=next_road_index_A, + road_index_B=next_road_index_B, + type=car_type, + ) + + @partial(jax.jit, static_argnums=(0,)) + def _advance_car_core( + self, + position_x: chex.Array, + position_y: chex.Array, + road_index_A: chex.Array, + road_index_B: chex.Array, + current_road: chex.Array, + speed: chex.Array, + step_counter: chex.Array, + width: chex.Array, + height: chex.Array, + car_type: chex.Array, + ) -> Car: + """Simplified car advancement for enemy cars (no jumping/landing logic).""" + dividers = jnp.array([0, 1, 2, 4, 8, 16, 16]) + abs_speed = jnp.abs(speed) + speed_index = jnp.minimum(abs_speed, jnp.int32(dividers.shape[0] - 1)) + speed_divider = dividers[speed_index] + effective_divider = jnp.maximum(1, speed_divider) + period = jnp.maximum(1, 16 // effective_divider) + half_period = jnp.maximum(1, period // 2) + speed_sign = jax.lax.cond(speed != 0, lambda _: jax.lax.abs(speed) / speed, lambda _: jnp.array(0.0), operand=None) + + slope, b = self._get_slope_and_intercept_from_indices(current_road, road_index_A, road_index_B) + + direction_raw = jax.lax.cond( + current_road == 0, + lambda _: self.consts.FIRST_TRACK_CORNERS_X[road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_X[road_index_A], + lambda _: self.consts.SECOND_TRACK_CORNERS_X[road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[road_index_B], + operand=None, + ) + car_direction_x = jax.lax.cond(direction_raw > 0, lambda _: 1, lambda _: -1, operand=None) + + move_y = jnp.logical_and((step_counter % period) == (half_period % period), speed != 0) + move_x = jnp.logical_and((step_counter % period) == 0, speed != 0) + + step_size = jnp.where(speed_index >= self.consts.MAX_SPEED, 1.5, 1.0) + + position = EntityPosition(x=position_x, y=position_y, width=width, height=height) + + new_y = jax.lax.cond( + move_y, lambda _: jax.lax.cond( - self.consts.TRACK_CORNERS_Y[road_index_B] < new_player_y, - lambda _: road_index_B - 1, - lambda _: jax.lax.cond( - len(self.consts.TRACK_CORNERS_Y) == road_index_B + 1, - lambda _: jax.lax.cond( - self.consts.TRACK_CORNERS_Y[0] > new_player_y, - lambda _: 0, - lambda _: road_index_B, - operand=None, - ), - lambda _: jax.lax.cond( - self.consts.TRACK_CORNERS_Y[road_index_B+1] > new_player_y, - lambda _: road_index_B + 1, - lambda _: road_index_B, - operand=None, - ), - operand=None, - ), + self._is_on_line_for_position(position, slope, b, speed_sign, 1), + lambda _: position_y + speed_sign * -step_size, + lambda _: jnp.array(position_y, float), + operand=None, + ), + lambda _: jnp.array(position_y, float), + operand=None, + ) + + new_x = jax.lax.cond( + move_x, + lambda _: jax.lax.cond( + self._is_on_line_for_position(position, slope, b, speed_sign, 2), + lambda _: position_x + speed_sign * car_direction_x * step_size, + lambda _: jnp.array(position_x, float), operand=None, ), + lambda _: jnp.array(position_x, float), operand=None, ) - wrapped_y = -((new_player_y * -1) % 1036) + wrapped_y = -((new_y * -1) % 1036) + + # Update road segment indices based on new position + segment_from_y = self._get_road_segment(new_y) + + # Update road indices to track the current segment + next_road_index_A = jax.lax.cond( + current_road == 0, + lambda _: segment_from_y, + lambda _: road_index_A, + operand=None, + ) + + next_road_index_B = jax.lax.cond( + current_road == 1, + lambda _: segment_from_y, + lambda _: road_index_B, + operand=None, + ) return Car( position=EntityPosition( - x=new_player_x, + x=new_x, y=wrapped_y, width=width, height=height, ), speed=speed, direction_x=car_direction_x, - current_road=updated_current_road, + current_road=current_road, road_index_A=next_road_index_A, road_index_B=next_road_index_B, type=car_type, @@ -563,11 +748,7 @@ def select_type(rand_val): type_id_spawn = select_type(rand_type) # Calculate X position on road - def get_road_segment(y): - segments = jnp.sum(self.consts.TRACK_CORNERS_Y > y) - return jnp.clip(segments - 1, 0, len(self.consts.TRACK_CORNERS_Y) - 2) - - segment_spawn = get_road_segment(y_spawn) + segment_spawn = self._get_road_segment(y_spawn) x_spawn = jax.lax.cond( road_spawn == 0, lambda _: self._get_x_on_road(y_spawn, segment_spawn, self.consts.FIRST_TRACK_CORNERS_X), @@ -716,20 +897,145 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: operand=player_speed, ) - is_jumping = jnp.logical_or(jnp.logical_and(state.is_jumping, state.jump_cooldown > 0), jnp.logical_and(state.is_on_road, jnp.logical_and(player_speed >= 0, jnp.logical_and(state.jump_cooldown == 0, jump)))) + # Check if on a steep road section (no X direction change) and apply speed reduction + # This simulates steep road sections that require a jump to pass when going upward + is_on_steep_road = self._is_steep_road_segment( + state.player_car.current_road, + state.player_car.road_index_A, + state.player_car.road_index_B, + ) + # Only apply steep road penalty when: + # 1. Player is on a steep road section + # 2. Player is not jumping + # 3. Player has positive speed (going upward) + on_steep_going_up = jnp.logical_and( + is_on_steep_road, + jnp.logical_and( + jnp.logical_not(state.is_jumping), + player_speed > 0 + ) + ) + # Update steep road timer - increment when on steep road going up, reset otherwise + steep_road_timer = jax.lax.cond( + on_steep_going_up, + lambda _: state.steep_road_timer + 1, + lambda _: jnp.array(0, dtype=jnp.int32), + operand=None, + ) + # Only reduce speed when timer reaches the interval threshold + should_reduce_speed = jnp.logical_and( + on_steep_going_up, + steep_road_timer >= self.consts.STEEP_ROAD_SPEED_REDUCTION_INTERVAL + ) + # Gradually reduce speed toward -2 when on steep section without jumping + player_speed = jax.lax.cond( + should_reduce_speed, + lambda s: jnp.maximum(s - 1, jnp.int32(-2)), + lambda s: s, + operand=player_speed, + ) + # Reset timer after speed reduction + steep_road_timer = jax.lax.cond( + should_reduce_speed, + lambda _: jnp.array(0, dtype=jnp.int32), + lambda _: steep_road_timer, + operand=None, + ) + + can_start_jump = jnp.logical_and(state.jump_cooldown == 0, state.post_jump_cooldown == 0) + is_jumping = jnp.logical_or( + jnp.logical_and(state.is_jumping, state.jump_cooldown > 0), + jnp.logical_and(state.is_on_road, jnp.logical_and(player_speed >= 0, jnp.logical_and(can_start_jump, jump))), + ) + + # Detect when a new jump is starting (was not jumping, now is jumping) + starting_jump = jnp.logical_and(is_jumping, jnp.logical_not(state.is_jumping)) + + # Calculate jump slope at jump start (X change per Y step) + # Uses the road segment slope to follow the road trajectory + road_index = jax.lax.cond( + state.player_car.current_road == 0, + lambda _: state.player_car.road_index_A, + lambda _: state.player_car.road_index_B, + operand=None, + ) + + # Get corner coordinates for the current segment + # Segment goes from corner[road_index] to corner[road_index+1] + start_x = jax.lax.cond( + state.player_car.current_road == 0, + lambda _: self.consts.FIRST_TRACK_CORNERS_X[road_index], + lambda _: self.consts.SECOND_TRACK_CORNERS_X[road_index], + operand=None, + ) + end_x = jax.lax.cond( + state.player_car.current_road == 0, + lambda _: self.consts.FIRST_TRACK_CORNERS_X[road_index +1], + lambda _: self.consts.SECOND_TRACK_CORNERS_X[road_index +1], + operand=None, + ) + start_y = self.consts.TRACK_CORNERS_Y[road_index] + + end_y = jax.lax.cond( + jnp.equal(self.consts.FIRST_TRACK_CORNERS_X[road_index + 1], self.consts.FIRST_TRACK_CORNERS_X[road_index + 2]), + lambda _: self.consts.TRACK_CORNERS_Y[road_index + 2], + lambda _: self.consts.TRACK_CORNERS_Y[road_index + 1], + operand=None + ) + + # Calculate slope: how much X changes per unit Y change + delta_x = end_x - start_x + delta_y = end_y - start_y + # Avoid division by zero for horizontal segments + new_jump_slope = jax.lax.cond( + jnp.abs(delta_y) > 0.001, + lambda _: jnp.float32(delta_x) / jnp.float32(delta_y), + lambda _: jnp.float32(0.0), + operand=None, + ) + + # Lock slope at jump start, keep previous slope during jump + jump_slope = jax.lax.cond( + starting_jump, + lambda _: new_jump_slope, + lambda _: state.jump_slope, + operand=None, + ) + jump_cooldown = jax.lax.cond( state.jump_cooldown > 0, lambda s: s - 1, - lambda s: jax.lax.cond(is_jumping, - lambda _: self.consts.JUMP_FRAMES, - lambda _: 0, - operand=None), + lambda s: jax.lax.cond( + is_jumping, + lambda _: self.consts.JUMP_FRAMES, + lambda _: 0, + operand=None, + ), operand=state.jump_cooldown, ) + + post_jump_cooldown = jax.lax.cond( + jnp.logical_and(state.jump_cooldown == 1, jump_cooldown == 0), + lambda _: self.consts.POST_JUMP_DELAY, + lambda _: jax.lax.cond( + state.post_jump_cooldown > 0, + lambda s: s - 1, + lambda s: s, + operand=state.post_jump_cooldown, + ), + operand=None, + ) is_on_road = ~is_jumping is_landing = jnp.logical_and(state.jump_cooldown == 1, jump_cooldown == 0) - updated_player_car = self._advance_car_core( + respawn_cooldown = jax.lax.cond( + state.respawn_cooldown > 0, + lambda _: state.respawn_cooldown - 1, + lambda _: jnp.array(0, dtype=jnp.int32), + operand=None, + ) + + updated_player_car = self._advance_player_car( position_x=state.player_car.position.x, position_y=state.player_car.position.y, road_index_A=state.player_car.road_index_A, @@ -742,30 +1048,41 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: width=state.player_car.position.width, height=state.player_car.position.height, car_type=state.player_car.type, - landing_check=is_landing, + is_landing=is_landing, + stored_jump_slope=jump_slope, ) - return UpNDownState( - score=state.score, - difficulty=state.difficulty, + # Check if a speed-changing action (UP or DOWN) was taken + speed_action_taken = jnp.logical_or(up, down) + # Round starts only after a speed-changing action + round_started_now = jnp.logical_or(state.round_started, speed_action_taken) + + next_state = state._replace( + respawn_cooldown=respawn_cooldown, jump_cooldown=jump_cooldown, + post_jump_cooldown=post_jump_cooldown, is_jumping=is_jumping, is_on_road=is_on_road, player_car=updated_player_car, step_counter=state.step_counter + 1, - round_started=jnp.logical_or(state.round_started, player_speed != 0), + round_started=round_started_now, movement_steps=jax.lax.cond( - jnp.logical_or(state.round_started, player_speed != 0), - lambda s: state.movement_steps + 1, - lambda s: state.movement_steps, + round_started_now, + lambda _: state.movement_steps + 1, + lambda _: state.movement_steps, operand=None, ), - flags=state.flags, - flags_collected_mask=state.flags_collected_mask, - collectibles=state.collectibles, - collectible_spawn_timer=state.collectible_spawn_timer, - enemy_cars=state.enemy_cars, - enemy_spawn_timer=state.enemy_spawn_timer, + steep_road_timer=steep_road_timer, + jump_slope=jump_slope, + ) + + water_crash = jnp.logical_and(is_landing, updated_player_car.current_road == 2) + + return jax.lax.cond( + water_crash, + lambda _: self._respawn_after_collision(next_state, next_state.lives - 1), + lambda _: next_state, + operand=None, ) def _flag_step_main(self, state: UpNDownState) -> UpNDownState: @@ -778,25 +1095,10 @@ def _flag_step_main(self, state: UpNDownState) -> UpNDownState: state, new_player_y, player_x, current_road ) - return UpNDownState( + return state._replace( score=state.score + flag_score, - difficulty=state.difficulty, - jump_cooldown=state.jump_cooldown, - is_jumping=state.is_jumping, - is_on_road=state.is_on_road, - player_car=state.player_car, - lives=state.lives, - is_dead=state.is_dead, - respawn_timer=state.respawn_timer, - step_counter=state.step_counter, - round_started=state.round_started, - movement_steps=state.movement_steps, flags=new_flags, flags_collected_mask=new_flags_collected_mask, - collectibles=state.collectibles, - collectible_spawn_timer=state.collectible_spawn_timer, - enemy_cars=state.enemy_cars, - enemy_spawn_timer=state.enemy_spawn_timer, ) def _collectible_step_main(self, state: UpNDownState) -> UpNDownState: @@ -809,25 +1111,76 @@ def _collectible_step_main(self, state: UpNDownState) -> UpNDownState: state, new_player_y, player_x, current_road ) - return UpNDownState( + return state._replace( score=state.score + collectible_score, - difficulty=state.difficulty, - jump_cooldown=state.jump_cooldown, - is_jumping=state.is_jumping, - is_on_road=state.is_on_road, - player_car=state.player_car, - lives=state.lives, - is_dead=state.is_dead, - respawn_timer=state.respawn_timer, - step_counter=state.step_counter, - round_started=state.round_started, - movement_steps=state.movement_steps, - flags=state.flags, - flags_collected_mask=state.flags_collected_mask, collectibles=updated_collectibles, collectible_spawn_timer=new_collectible_timer, - enemy_cars=state.enemy_cars, - enemy_spawn_timer=state.enemy_spawn_timer, + ) + + def _initialize_collectibles(self) -> Collectible: + """Return a cleared collectible pool.""" + return Collectible( + y=jnp.zeros(self.consts.MAX_COLLECTIBLES), + x=jnp.zeros(self.consts.MAX_COLLECTIBLES), + road=jnp.zeros(self.consts.MAX_COLLECTIBLES, dtype=jnp.int32), + color_idx=jnp.zeros(self.consts.MAX_COLLECTIBLES, dtype=jnp.int32), + type_id=jnp.zeros(self.consts.MAX_COLLECTIBLES, dtype=jnp.int32), + active=jnp.zeros(self.consts.MAX_COLLECTIBLES, dtype=jnp.bool_), + ) + + @partial(jax.jit, static_argnums=(0,)) + def _initialize_enemies(self, key: chex.Array, player_start_y: chex.Array) -> EnemyCars: + """Seed the initial set of visible enemies around the player.""" + key_init, key_type, key_road, key_speed, key_sign = jax.random.split(key, 5) + + offsets = self.consts.INITIAL_ENEMY_BASE_OFFSET + self.consts.INITIAL_ENEMY_GAP * jnp.arange(self.consts.INITIAL_ENEMY_COUNT) + spawn_signs = jax.random.choice(key_sign, jnp.array([-1.0, 1.0]), shape=(self.consts.INITIAL_ENEMY_COUNT,)) + raw_spawn_y = player_start_y + spawn_signs * offsets + init_y = -(((raw_spawn_y) * -1) % self.consts.TRACK_LENGTH) + init_road = jax.random.randint(key_road, shape=(self.consts.INITIAL_ENEMY_COUNT,), minval=0, maxval=2) + + init_segments = jax.vmap(self._get_road_segment)(init_y) + + init_x = jax.vmap(lambda y, seg, road: jax.lax.cond( + road == 0, + lambda _: self._get_x_on_road(y, seg, self.consts.FIRST_TRACK_CORNERS_X), + lambda _: self._get_x_on_road(y, seg, self.consts.SECOND_TRACK_CORNERS_X), + operand=None, + ))(init_y, init_segments, init_road) + + init_type = jax.random.randint(key_type, shape=(self.consts.INITIAL_ENEMY_COUNT,), minval=0, maxval=4) + init_speed_mag = jax.random.randint(key_speed, shape=(self.consts.INITIAL_ENEMY_COUNT,), minval=self.consts.ENEMY_SPEED_MIN, maxval=self.consts.ENEMY_SPEED_MAX + 1) + init_speed_sign = jax.random.choice(key_init, jnp.array([-1, 1]), shape=(self.consts.INITIAL_ENEMY_COUNT,)) + init_speed = init_speed_mag * init_speed_sign + + def init_direction(seg, road): + raw = jax.lax.cond( + road == 0, + lambda _: self.consts.FIRST_TRACK_CORNERS_X[seg+1] - self.consts.FIRST_TRACK_CORNERS_X[seg], + lambda _: self.consts.SECOND_TRACK_CORNERS_X[seg+1] - self.consts.SECOND_TRACK_CORNERS_X[seg], + operand=None, + ) + return jax.lax.cond(raw > 0, lambda _: 1, lambda _: -1, operand=None) + + init_dir = jax.vmap(init_direction)(init_segments, init_road) + + pad = self.consts.MAX_ENEMY_CARS - self.consts.INITIAL_ENEMY_COUNT + + return EnemyCars( + position=EntityPosition( + x=jnp.concatenate([init_x, jnp.zeros(pad, dtype=jnp.float32)]), + y=jnp.concatenate([init_y, jnp.zeros(pad, dtype=jnp.float32)]), + width=jnp.full((self.consts.MAX_ENEMY_CARS,), self.consts.PLAYER_SIZE[0]), + height=jnp.full((self.consts.MAX_ENEMY_CARS,), self.consts.PLAYER_SIZE[1]), + ), + speed=jnp.concatenate([init_speed, jnp.zeros(pad, dtype=jnp.int32)]), + type=jnp.concatenate([init_type, jnp.zeros(pad, dtype=jnp.int32)]), + current_road=jnp.concatenate([init_road, jnp.zeros(pad, dtype=jnp.int32)]), + road_index_A=jnp.concatenate([init_segments, jnp.zeros(pad, dtype=jnp.int32)]), + road_index_B=jnp.concatenate([init_segments, jnp.zeros(pad, dtype=jnp.int32)]), + direction_x=jnp.concatenate([init_dir, jnp.zeros(pad, dtype=jnp.int32)]), + active=jnp.concatenate([jnp.ones(self.consts.INITIAL_ENEMY_COUNT, dtype=jnp.bool_), jnp.zeros(pad, dtype=jnp.bool_)]), + age=jnp.concatenate([jnp.zeros(self.consts.INITIAL_ENEMY_COUNT, dtype=jnp.int32), jnp.zeros(pad, dtype=jnp.int32)]), ) @partial(jax.jit, static_argnums=(0,)) @@ -861,11 +1214,7 @@ def _enemy_step_main(self, state: UpNDownState) -> UpNDownState: spawn_y = -(((raw_spawn_y) * -1) % 1036) spawn_road = jax.random.randint(key_spawn_direction, shape=(), minval=0, maxval=2) - def get_road_segment(y): - segments = jnp.sum(self.consts.TRACK_CORNERS_Y > y) - return jnp.clip(segments - 1, 0, len(self.consts.TRACK_CORNERS_Y) - 2) - - segment_spawn = get_road_segment(spawn_y) + segment_spawn = self._get_road_segment(spawn_y) spawn_x = jax.lax.cond( spawn_road == 0, lambda _: self._get_x_on_road(spawn_y, segment_spawn, self.consts.FIRST_TRACK_CORNERS_X), @@ -910,13 +1259,10 @@ def get_road_segment(y): road_index_B=rb, current_road=cr, speed=sp, - is_jumping=False, - is_on_road=True, step_counter=state.step_counter, width=self.consts.PLAYER_SIZE[0], height=self.consts.PLAYER_SIZE[1], car_type=tp, - landing_check=False, ) advanced_cars = jax.vmap(move_fn)( @@ -963,22 +1309,155 @@ def get_road_segment(y): age=enemy_age, ) + return state._replace( + enemy_cars=next_enemy_cars, + enemy_spawn_timer=spawn_timer, + ) + + @partial(jax.jit, static_argnums=(0,)) + def _respawn_after_collision(self, state: UpNDownState, new_lives: chex.Array) -> UpNDownState: + """Respawn the player on a random road while preserving score and flags.""" + base_key = jax.random.PRNGKey(1337) + key_spawn = jax.random.fold_in(base_key, state.step_counter) + road_key, enemy_key = jax.random.split(key_spawn, 2) + + player_start_y = jnp.array(0.0) + start_segment = jnp.array(0, dtype=jnp.int32) + respawn_road = jax.random.randint(road_key, shape=(), minval=0, maxval=2) + + start_x = jax.lax.cond( + respawn_road == 0, + lambda _: self._get_x_on_road(player_start_y, start_segment, self.consts.FIRST_TRACK_CORNERS_X), + lambda _: self._get_x_on_road(player_start_y, start_segment, self.consts.SECOND_TRACK_CORNERS_X), + operand=None, + ) + + enemy_cars = self._initialize_enemies(enemy_key, player_start_y) + collectibles = self._initialize_collectibles() + + player_car = Car( + position=EntityPosition( + x=jnp.asarray(start_x, dtype=jnp.float32), + y=jnp.asarray(player_start_y, dtype=jnp.float32), + width=self.consts.PLAYER_SIZE[0], + height=self.consts.PLAYER_SIZE[1], + ), + speed=jnp.array(0, dtype=jnp.int32), + direction_x=jnp.array(0, dtype=jnp.int32), + current_road=respawn_road, + road_index_A=start_segment, + road_index_B=start_segment, + type=jnp.array(0, dtype=jnp.int32), + ) + return UpNDownState( score=state.score, + lives=new_lives, + respawn_cooldown=jnp.array(self.consts.RESPAWN_HIDE_FRAMES, dtype=jnp.int32), difficulty=state.difficulty, - jump_cooldown=state.jump_cooldown, - is_jumping=state.is_jumping, - is_on_road=state.is_on_road, - player_car=state.player_car, + jump_cooldown=jnp.array(0, dtype=jnp.int32), + post_jump_cooldown=jnp.array(0, dtype=jnp.int32), + is_jumping=jnp.array(False), + is_on_road=jnp.array(True), + player_car=player_car, step_counter=state.step_counter, - round_started=state.round_started, - movement_steps=state.movement_steps, + round_started=jnp.array(False), + movement_steps=jnp.array(0), + steep_road_timer=jnp.array(0, dtype=jnp.int32), + jump_slope=jnp.array(0.0, dtype=jnp.float32), flags=state.flags, flags_collected_mask=state.flags_collected_mask, - collectibles=state.collectibles, - collectible_spawn_timer=state.collectible_spawn_timer, - enemy_cars=next_enemy_cars, - enemy_spawn_timer=spawn_timer, + collectibles=collectibles, + collectible_spawn_timer=jnp.array(0, dtype=jnp.int32), + enemy_cars=enemy_cars, + enemy_spawn_timer=jnp.array(self.consts.ENEMY_SPAWN_INTERVAL, dtype=jnp.int32), + ) + + @partial(jax.jit, static_argnums=(0,)) + def _enemy_collision_step_main(self, state: UpNDownState) -> UpNDownState: + """Handle collisions between the player and enemy cars. + + - While airborne, collisions are ignored except for the final jump frames, + where hitting an enemy despawns it and awards a bonus. + - On ground collisions, the player loses a life and the stage soft-resets + without clearing score or collected flags. + - Landing collisions use a larger distance and are road-independent (for crossings). + """ + player_x = state.player_car.position.x + player_y = state.player_car.position.y + + dx = jnp.abs(state.enemy_cars.position.x - player_x) + dy = jnp.abs(state.enemy_cars.position.y - player_y) + wrapped_dy = jnp.minimum(dy, self.consts.TRACK_LENGTH - dy) + + # For ground collision: only trigger when enemy position is within 3 pixels + ground_collision_distance = 3.0 + overlap_x_ground = dx <= ground_collision_distance + overlap_y_ground = wrapped_dy <= ground_collision_distance + # For landing collision: use larger distance and road-independent (for crossings) + overlap_x_landing = dx <= self.consts.LANDING_COLLISION_DISTANCE + overlap_y_landing = wrapped_dy <= self.consts.LANDING_COLLISION_DISTANCE + # For late jump collision: use original larger overlap based on car dimensions + overlap_x_jump = dx <= (state.player_car.position.width + state.enemy_cars.position.width) / 2.0 + overlap_y_jump = wrapped_dy <= (state.player_car.position.height + state.enemy_cars.position.height) / 2.0 + same_road = state.enemy_cars.current_road == state.player_car.current_road + + # Ground collision mask uses tight 3-pixel distance and same road + ground_collision_mask = jnp.logical_and(state.enemy_cars.active, jnp.logical_and(same_road, jnp.logical_and(overlap_x_ground, overlap_y_ground))) + # Landing collision mask uses larger distance and is road-independent (for crossings) + landing_collision_mask = jnp.logical_and(state.enemy_cars.active, jnp.logical_and(overlap_x_landing, overlap_y_landing)) + # Jump collision mask uses original larger overlap (for scoring when jumping on enemies) + jump_collision_mask = jnp.logical_and(state.enemy_cars.active, jnp.logical_and(same_road, jnp.logical_and(overlap_x_jump, overlap_y_jump))) + collision_mask = jump_collision_mask # For late jump scoring + + any_jump_collision = jnp.any(jump_collision_mask) + any_ground_collision = jnp.any(ground_collision_mask) + any_landing_collision = jnp.any(landing_collision_mask) + + # Check if player is in the landing phase (just landed from a jump) + is_landing_phase = jnp.logical_and(state.post_jump_cooldown > 0, state.post_jump_cooldown <= self.consts.POST_JUMP_DELAY) + + late_jump_window = jnp.logical_and(state.is_jumping, state.jump_cooldown <= self.consts.LATE_JUMP_COLLISION_FRAMES) + late_jump_collision = jnp.logical_and(any_jump_collision, late_jump_window) + grounded_collision = jnp.logical_and(any_ground_collision, jnp.logical_not(state.is_jumping)) + # Landing collision is road-independent and uses larger distance + landing_collision = jnp.logical_and(any_landing_collision, is_landing_phase) + + def handle_late_jump(): + hits = collision_mask.astype(jnp.int32) + bonus = jnp.sum(hits) * self.consts.LATE_JUMP_ENEMY_SCORE + new_enemy_active = jnp.logical_and(state.enemy_cars.active, jnp.logical_not(collision_mask)) + new_enemy_age = jnp.where(collision_mask, jnp.zeros_like(state.enemy_cars.age), state.enemy_cars.age) + new_enemy_cars = EnemyCars( + position=state.enemy_cars.position, + speed=state.enemy_cars.speed, + type=state.enemy_cars.type, + current_road=state.enemy_cars.current_road, + road_index_A=state.enemy_cars.road_index_A, + road_index_B=state.enemy_cars.road_index_B, + direction_x=state.enemy_cars.direction_x, + active=new_enemy_active, + age=new_enemy_age, + ) + + return state._replace(score=state.score + bonus, enemy_cars=new_enemy_cars) + + def handle_ground_collision(): + return self._respawn_after_collision(state, state.lives - 1) + + # Check for any collision that should cause respawn (ground or landing) + any_fatal_collision = jnp.logical_or(grounded_collision, landing_collision) + + return jax.lax.cond( + late_jump_collision, + lambda _: handle_late_jump(), + lambda _: jax.lax.cond( + any_fatal_collision, + lambda _: handle_ground_collision(), + lambda _: state, + operand=None, + ), + operand=None, ) def _passive_score_step_main(self, state: UpNDownState) -> UpNDownState: @@ -990,26 +1469,7 @@ def _passive_score_step_main(self, state: UpNDownState) -> UpNDownState: operand=None, ) - return UpNDownState( - score=state.score + bonus, - difficulty=state.difficulty, - jump_cooldown=state.jump_cooldown, - is_jumping=state.is_jumping, - is_on_road=state.is_on_road, - player_car=state.player_car, - lives=state.lives, - is_dead=state.is_dead, - respawn_timer=state.respawn_timer, - step_counter=state.step_counter, - round_started=state.round_started, - movement_steps=state.movement_steps, - flags=state.flags, - flags_collected_mask=state.flags_collected_mask, - collectibles=state.collectibles, - collectible_spawn_timer=state.collectible_spawn_timer, - enemy_cars=state.enemy_cars, - enemy_spawn_timer=state.enemy_spawn_timer, - ) + return state._replace(score=state.score + bonus) def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: @@ -1018,25 +1478,17 @@ def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: if key is None: key = jax.random.PRNGKey(42) + key, flag_key, enemy_key = jax.random.split(key, 3) # Evenly spread flags along the track with small jitter - key, subkey = jax.random.split(key) base_y = jnp.linspace(-900.0, -100.0, self.consts.NUM_FLAGS) - jitter = jax.random.uniform(subkey, shape=(self.consts.NUM_FLAGS,), minval=-40.0, maxval=40.0) + jitter = jax.random.uniform(flag_key, shape=(self.consts.NUM_FLAGS,), minval=-40.0, maxval=40.0) flag_y_offsets = base_y + jitter # Alternate roads 0/1 for variety flag_roads = jnp.arange(self.consts.NUM_FLAGS) % 2 - - - # Calculate which road segment each flag is on based on Y position - def get_road_segment(y): - # Find the segment where TRACK_CORNERS_Y[i] > y >= TRACK_CORNERS_Y[i+1] - segments = jnp.sum(self.consts.TRACK_CORNERS_Y > y) - return jnp.clip(segments - 1, 0, len(self.consts.TRACK_CORNERS_Y) - 2) - - flag_segments = jax.vmap(get_road_segment)(flag_y_offsets) + flag_segments = jax.vmap(self._get_road_segment)(flag_y_offsets) # Each flag color index corresponds to its position (0-7) flag_color_indices = jnp.arange(self.consts.NUM_FLAGS) @@ -1050,78 +1502,25 @@ def get_road_segment(y): ) # Initialize collectibles as all inactive (will spawn dynamically with mixed types) - collectibles = Collectible( - y=jnp.zeros(self.consts.MAX_COLLECTIBLES), - x=jnp.zeros(self.consts.MAX_COLLECTIBLES), - road=jnp.zeros(self.consts.MAX_COLLECTIBLES, dtype=jnp.int32), - color_idx=jnp.zeros(self.consts.MAX_COLLECTIBLES, dtype=jnp.int32), - type_id=jnp.zeros(self.consts.MAX_COLLECTIBLES, dtype=jnp.int32), - active=jnp.zeros(self.consts.MAX_COLLECTIBLES, dtype=jnp.bool_), - ) - - def get_road_segment(y): - segments = jnp.sum(self.consts.TRACK_CORNERS_Y > y) - return jnp.clip(segments - 1, 0, len(self.consts.TRACK_CORNERS_Y) - 2) + collectibles = self._initialize_collectibles() # Seed initial visible enemies spaced around the player - key_init, key_type, key_road, key_speed, key_sign = jax.random.split(key, 5) - player_start_y = 0.0 - offsets = self.consts.INITIAL_ENEMY_BASE_OFFSET + self.consts.INITIAL_ENEMY_GAP * jnp.arange(self.consts.INITIAL_ENEMY_COUNT) - spawn_signs = jax.random.choice(key_sign, jnp.array([-1.0, 1.0]), shape=(self.consts.INITIAL_ENEMY_COUNT,)) - raw_spawn_y = player_start_y + spawn_signs * offsets - init_y = -(((raw_spawn_y) * -1) % 1036) - init_road = jax.random.randint(key_road, shape=(self.consts.INITIAL_ENEMY_COUNT,), minval=0, maxval=2) - init_segments = jax.vmap(get_road_segment)(init_y) - init_x = jax.vmap(lambda y, seg, road: jax.lax.cond( - road == 0, - lambda _: self._get_x_on_road(y, seg, self.consts.FIRST_TRACK_CORNERS_X), - lambda _: self._get_x_on_road(y, seg, self.consts.SECOND_TRACK_CORNERS_X), - operand=None, - ))(init_y, init_segments, init_road) - init_type = jax.random.randint(key_type, shape=(self.consts.INITIAL_ENEMY_COUNT,), minval=0, maxval=4) - init_speed_mag = jax.random.randint(key_speed, shape=(self.consts.INITIAL_ENEMY_COUNT,), minval=self.consts.ENEMY_SPEED_MIN, maxval=self.consts.ENEMY_SPEED_MAX + 1) - init_speed_sign = jax.random.choice(key_init, jnp.array([-1, 1]), shape=(self.consts.INITIAL_ENEMY_COUNT,)) - init_speed = init_speed_mag * init_speed_sign - - def init_direction(seg, road): - raw = jax.lax.cond( - road == 0, - lambda _: self.consts.FIRST_TRACK_CORNERS_X[seg+1] - self.consts.FIRST_TRACK_CORNERS_X[seg], - lambda _: self.consts.SECOND_TRACK_CORNERS_X[seg+1] - self.consts.SECOND_TRACK_CORNERS_X[seg], - operand=None, - ) - return jax.lax.cond(raw > 0, lambda _: 1, lambda _: -1, operand=None) - - init_dir = jax.vmap(init_direction)(init_segments, init_road) - - pad = self.consts.MAX_ENEMY_CARS - self.consts.INITIAL_ENEMY_COUNT - enemy_cars = EnemyCars( - position=EntityPosition( - x=jnp.concatenate([init_x, jnp.zeros(pad, dtype=jnp.float32)]), - y=jnp.concatenate([init_y, jnp.zeros(pad, dtype=jnp.float32)]), - width=jnp.full((self.consts.MAX_ENEMY_CARS,), self.consts.PLAYER_SIZE[0]), - height=jnp.full((self.consts.MAX_ENEMY_CARS,), self.consts.PLAYER_SIZE[1]), - ), - speed=jnp.concatenate([init_speed, jnp.zeros(pad, dtype=jnp.int32)]), - type=jnp.concatenate([init_type, jnp.zeros(pad, dtype=jnp.int32)]), - current_road=jnp.concatenate([init_road, jnp.zeros(pad, dtype=jnp.int32)]), - road_index_A=jnp.concatenate([init_segments, jnp.zeros(pad, dtype=jnp.int32)]), - road_index_B=jnp.concatenate([init_segments, jnp.zeros(pad, dtype=jnp.int32)]), - direction_x=jnp.concatenate([init_dir, jnp.zeros(pad, dtype=jnp.int32)]), - active=jnp.concatenate([jnp.ones(self.consts.INITIAL_ENEMY_COUNT, dtype=jnp.bool_), jnp.zeros(pad, dtype=jnp.bool_)]), - age=jnp.concatenate([jnp.zeros(self.consts.INITIAL_ENEMY_COUNT, dtype=jnp.int32), jnp.zeros(pad, dtype=jnp.int32)]), - ) + player_start_y = jnp.array(0.0) + enemy_cars = self._initialize_enemies(enemy_key, player_start_y) state = UpNDownState( score=0, + lives=jnp.array(self.consts.INITIAL_LIVES, dtype=jnp.int32), + respawn_cooldown=jnp.array(0, dtype=jnp.int32), difficulty=self.consts.DIFFICULTIES[0], jump_cooldown=0, + post_jump_cooldown=0, is_jumping=False, is_on_road=True, player_car=Car( position=EntityPosition( - x=30, - y= 0, + x=jnp.asarray(30, dtype=jnp.float32), + y=jnp.asarray(0, dtype=jnp.float32), width=self.consts.PLAYER_SIZE[0], height=self.consts.PLAYER_SIZE[1], ), @@ -1135,6 +1534,8 @@ def init_direction(seg, road): step_counter=jnp.array(0), round_started=jnp.array(False), movement_steps=jnp.array(0), + steep_road_timer=jnp.array(0, dtype=jnp.int32), + jump_slope=jnp.array(0.0, dtype=jnp.float32), flags=flags, flags_collected_mask=jnp.zeros(self.consts.NUM_FLAGS, dtype=jnp.bool_), collectibles=collectibles, @@ -1156,6 +1557,7 @@ def step(self, state: UpNDownState, action: chex.Array) -> Tuple[UpNDownObservat state = self._completion_bonus_step(state) state = self._collectible_step_main(state) state = self._enemy_step_main(state) + state = self._enemy_collision_step_main(state) done = self._get_done(state) env_reward = self._get_reward(previous_state, state) @@ -1166,28 +1568,30 @@ def step(self, state: UpNDownState, action: chex.Array) -> Tuple[UpNDownObservat def render(self, state: UpNDownState) -> jnp.ndarray: - return self.renderer.render(state) + frame = self.renderer.render(state) + return jnp.asarray(frame, dtype=jnp.uint8) def _get_observation(self, state: UpNDownState): + # Clamp to screen-friendly coordinates so observation_space.contains passes + x = jnp.int32(jnp.clip(state.player_car.position.x, 0, 160)) + screen_y = jnp.int32(105) + player = EntityPosition( - x=jnp.array(state.player_car.position.x), - y=jnp.array(state.player_car.position.y), - width=jnp.array(self.consts.PLAYER_SIZE[0]), - height=jnp.array(self.consts.PLAYER_SIZE[1]), - ) - return UpNDownObservation( - player=player, + x=x, + y=screen_y, + width=jnp.int32(self.consts.PLAYER_SIZE[0]), + height=jnp.int32(self.consts.PLAYER_SIZE[1]), ) + return UpNDownObservation(player=player) @partial(jax.jit, static_argnums=(0,)) def obs_to_flat_array(self, obs: UpNDownObservation) -> jnp.ndarray: - return jnp.concatenate([ - obs.player.x.flatten(), - obs.player.y.flatten(), - obs.player.height.flatten(), - obs.player.width.flatten(), - ] - ) + return jnp.concatenate([ + jnp.asarray(obs.player.x, dtype=jnp.int32).reshape(-1), + jnp.asarray(obs.player.y, dtype=jnp.int32).reshape(-1), + jnp.asarray(obs.player.height, dtype=jnp.int32).reshape(-1), + jnp.asarray(obs.player.width, dtype=jnp.int32).reshape(-1), + ]) def action_space(self) -> spaces.Discrete: return spaces.Discrete(6) @@ -1212,20 +1616,19 @@ def image_space(self) -> spaces.Box: @partial(jax.jit, static_argnums=(0,)) def _get_info(self, state: UpNDownState, ) -> UpNDownInfo: - return UpNDownInfo(time=1) + return UpNDownInfo(time=jnp.asarray(state.step_counter, dtype=jnp.int32)) @partial(jax.jit, static_argnums=(0,)) def _get_reward(self, previous_state: UpNDownState, state: UpNDownState): - return state.score + base_delta = jnp.asarray(state.score - previous_state.score, dtype=jnp.float32) + if self.reward_funcs: + extras = jnp.sum(jnp.array([fn(previous_state, state) for fn in self.reward_funcs], dtype=jnp.float32)) + return base_delta + extras + return base_delta @partial(jax.jit, static_argnums=(0,)) def _get_done(self, state: UpNDownState) -> bool: - return jnp.logical_or( - state.lives <= 0, - jnp.all(state.flags_collected_mask), -) - - + return state.lives <= 0 class UpNDownRenderer(JAXGameRenderer): def __init__(self, consts: UpNDownConstants = None): @@ -1357,6 +1760,29 @@ def _compute_flag_palette_ids(self) -> jnp.ndarray: """Precompute palette indices for each flag color without special-casing pink.""" return jnp.array([self._find_palette_id(color) for color in self.consts.FLAG_COLORS], dtype=jnp.int32) + @partial(jax.jit, static_argnums=(0,)) + def _jump_arc_offset(self, jump_cooldown: chex.Array) -> chex.Array: + """Return a simple parabolic jump height based on remaining jump frames.""" + total = jnp.array(self.consts.JUMP_FRAMES, dtype=jnp.float32) + remaining = jnp.array(jump_cooldown, dtype=jnp.float32) + progress = jnp.clip((total - remaining) / jnp.maximum(total, 1.0), 0.0, 1.0) + centered = (progress - 0.5) * 2.0 + return self.consts.JUMP_ARC_HEIGHT * (1.0 - centered * centered) + + def _get_x_on_road(self, y: chex.Array, road_segment: chex.Array, track_corners_x: chex.Array) -> chex.Array: + """Linear interpolation of x along the given road segment for y.""" + y1 = self.consts.TRACK_CORNERS_Y[road_segment] + y2 = self.consts.TRACK_CORNERS_Y[road_segment + 1] + x1 = track_corners_x[road_segment] + x2 = track_corners_x[road_segment + 1] + t = jax.lax.cond( + y2 != y1, + lambda _: (y - y1) / (y2 - y1), + lambda _: 0.0, + operand=None, + ) + return x1 + t * (x2 - x1) + def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.ndarray, bottomBlockSprite: jnp.ndarray, tempPointer: jnp.ndarray, blackoutSquare: jnp.ndarray) -> tuple[list, list[str]]: """Returns the asset manifest and ordered road files.""" road_dir = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/roads" @@ -1442,223 +1868,214 @@ def combine(i, acc): raster = jax.lax.fori_loop(0, total_segments, combine, raster) - def select_enemy_mask(enemy_type: chex.Array, going_left: chex.Array): - left_mask = self.enemy_left_masks[enemy_type] - right_mask = self.enemy_right_masks[enemy_type] - return jax.lax.cond(going_left, lambda _: left_mask, lambda _: right_mask, operand=None) - - def render_enemy(carry, enemy_idx): - raster = carry - enemy_active = state.enemy_cars.active[enemy_idx] - enemy_x = state.enemy_cars.position.x[enemy_idx] - enemy_y = state.enemy_cars.position.y[enemy_idx] - enemy_type = state.enemy_cars.type[enemy_idx] - direction_x = state.enemy_cars.direction_x[enemy_idx] - screen_y = 105 + (enemy_y - state.player_car.position.y) - is_visible = jnp.logical_and(enemy_active, jnp.logical_and(screen_y > 25, screen_y < 195)) - enemy_mask = select_enemy_mask(enemy_type, direction_x < 0) - - raster = jax.lax.cond( - is_visible, - lambda r: self.jr.render_at(r, enemy_x.astype(jnp.int32), screen_y.astype(jnp.int32), enemy_mask), - lambda r: r, - operand=raster, + hide_world = state.respawn_cooldown > 0 + + # During respawn hide, only show the road/background to emulate an initial road state. + def render_roads_only(): + return raster + + def render_full_scene(): + def select_enemy_mask(enemy_type: chex.Array, going_left: chex.Array): + left_mask = self.enemy_left_masks[enemy_type] + right_mask = self.enemy_right_masks[enemy_type] + return jax.lax.cond(going_left, lambda _: left_mask, lambda _: right_mask, operand=None) + + def render_enemy(carry, enemy_idx): + raster = carry + enemy_active = state.enemy_cars.active[enemy_idx] + enemy_x = state.enemy_cars.position.x[enemy_idx] + enemy_y = state.enemy_cars.position.y[enemy_idx] + enemy_type = state.enemy_cars.type[enemy_idx] + direction_x = state.enemy_cars.direction_x[enemy_idx] + screen_y = 105 + (enemy_y - state.player_car.position.y) + is_visible = jnp.logical_and(enemy_active, jnp.logical_and(screen_y > 25, screen_y < 195)) + enemy_mask = select_enemy_mask(enemy_type, direction_x < 0) + + raster = jax.lax.cond( + is_visible, + lambda r: self.jr.render_at(r, enemy_x.astype(jnp.int32), screen_y.astype(jnp.int32), enemy_mask), + lambda r: r, + operand=raster, + ) + return raster, None + + raster_enemies, _ = jax.lax.scan(render_enemy, raster, jnp.arange(self.consts.MAX_ENEMY_CARS)) + + jump_offset = jax.lax.cond( + state.is_jumping, + lambda _: self._jump_arc_offset(state.jump_cooldown), + lambda _: jnp.array(0.0, dtype=jnp.float32), + operand=None, ) - return raster, None - raster, _ = jax.lax.scan(render_enemy, raster, jnp.arange(self.consts.MAX_ENEMY_CARS)) + player_screen_y = jnp.int32(105 - jump_offset) + player_mask = self.SHAPE_MASKS["player"] + raster_player = self.jr.render_at(raster_enemies, state.player_car.position.x, player_screen_y, player_mask) - player_mask = self.SHAPE_MASKS["player"] - raster = self.jr.render_at(raster, state.player_car.position.x, 105, player_mask) + wall_top_mask = self.SHAPE_MASKS["wall_top"] + raster_wall_top = self.jr.render_at(raster_player, 0, 0, wall_top_mask) - wall_top_mask = self.SHAPE_MASKS["wall_top"] - raster = self.jr.render_at(raster, 0, 0, wall_top_mask) + wall_bottom_mask = self.SHAPE_MASKS["wall_bottom"] + raster_wall_bottom = self.jr.render_at(raster_wall_top, 0, 210 - wall_bottom_mask.shape[0], wall_bottom_mask) - wall_bottom_mask = self.SHAPE_MASKS["wall_bottom"] - raster = self.jr.render_at(raster, 0, 210 - wall_bottom_mask.shape[0], wall_bottom_mask) + all_flags_top_mask = self.SHAPE_MASKS["all_flags_top"] + raster_flags_top = self.jr.render_at(raster_wall_bottom, 10, 20, all_flags_top_mask) - all_flags_top_mask = self.SHAPE_MASKS["all_flags_top"] - raster = self.jr.render_at(raster, 10, 20, all_flags_top_mask) + return raster_flags_top - # Render score centered at the top using dedicated score digit sprites - score_digits = self.jr.int_to_digits(state.score, max_digits=self.score_max_digits) - non_zero_mask = score_digits != 0 - has_non_zero = jnp.any(non_zero_mask) - first_non_zero = jnp.argmax(non_zero_mask) - start_index = jax.lax.select(has_non_zero, first_non_zero, self.score_max_digits - 1) - num_to_render = jax.lax.select(has_non_zero, self.score_max_digits - start_index, 1) - total_width = num_to_render * self.score_digit_spacing - score_x = self.score_center_x - (total_width // 2) + def render_rest(raster_input): + # Render score centered at the top using dedicated score digit sprites + score_digits = self.jr.int_to_digits(state.score, max_digits=self.score_max_digits) + non_zero_mask = score_digits != 0 + has_non_zero = jnp.any(non_zero_mask) + first_non_zero = jnp.argmax(non_zero_mask) + start_index = jax.lax.select(has_non_zero, first_non_zero, self.score_max_digits - 1) + num_to_render = jax.lax.select(has_non_zero, self.score_max_digits - start_index, 1) - raster = self.jr.render_label_selective( - raster, - jnp.int32(score_x), - self.score_render_y, - score_digits, - self.score_digit_masks, - start_index, - num_to_render, - spacing=self.score_digit_spacing, - max_digits_to_render=self.score_max_digits, - ) + total_width = num_to_render * self.score_digit_spacing + score_x = self.score_center_x - (total_width // 2) - # Render flags on the road - flag_pole_mask = self.SHAPE_MASKS["flag_pole"] - - def render_flag(carry, flag_idx): - raster = carry - flag_y = state.flags.y[flag_idx] - flag_road = state.flags.road[flag_idx] - flag_segment = state.flags.road_segment[flag_idx] - flag_collected = state.flags.collected[flag_idx] - flag_color_idx = state.flags.color_idx[flag_idx] - - # Calculate flag X position on its road - flag_x = jax.lax.cond( - flag_road == 0, - lambda _: self._get_flag_x_on_road(flag_y, flag_segment, self.consts.FIRST_TRACK_CORNERS_X), - lambda _: self._get_flag_x_on_road(flag_y, flag_segment, self.consts.SECOND_TRACK_CORNERS_X), - operand=None, - ) - - # Calculate screen Y position relative to player - # The player is always rendered at Y=105, so flags scroll based on player position - screen_y = 105 + (flag_y - state.player_car.position.y) - - # Check if flag is visible on screen and not collected - is_visible = jnp.logical_and( - jnp.logical_and(screen_y > 25, screen_y < 195), - ~flag_collected - ) - - # Colorize the base flag mask - color_id = self.flag_palette_ids[flag_color_idx] - colored_flag_mask = jnp.where( - self.flag_solid_mask, - color_id, - self.flag_base_mask, + raster_score = self.jr.render_label_selective( + raster_input, + jnp.int32(score_x), + self.score_render_y, + score_digits, + self.score_digit_masks, + start_index, + num_to_render, + spacing=self.score_digit_spacing, + max_digits_to_render=self.score_max_digits, ) + + # Render flags on the road + flag_pole_mask = self.SHAPE_MASKS["flag_pole"] - # Render flag if visible - raster = jax.lax.cond( - is_visible, - lambda r: self.jr.render_at( - self.jr.render_at(r, flag_x.astype(jnp.int32), screen_y.astype(jnp.int32), colored_flag_mask), - (flag_x + 5).astype(jnp.int32), screen_y.astype(jnp.int32), flag_pole_mask - ), - lambda r: r, - operand=raster, - ) - return raster, None - - raster, _ = jax.lax.scan(render_flag, raster, jnp.arange(self.consts.NUM_FLAGS)) - - # Black out collected flags at the top - blackout_mask = self.SHAPE_MASKS["blackout_square"] - - def render_blackout(carry, flag_idx): - raster = carry - flag_collected = state.flags_collected_mask[flag_idx] - blackout_x = self.consts.FLAG_TOP_X_POSITIONS[flag_idx] - blackout_y = self.consts.FLAG_TOP_Y + def render_flag(carry, flag_idx): + raster = carry + flag_y = state.flags.y[flag_idx] + flag_road = state.flags.road[flag_idx] + flag_segment = state.flags.road_segment[flag_idx] + flag_collected = state.flags.collected[flag_idx] + flag_color_idx = state.flags.color_idx[flag_idx] + + flag_x = jax.lax.cond( + flag_road == 0, + lambda _: self._get_x_on_road(flag_y, flag_segment, self.consts.FIRST_TRACK_CORNERS_X), + lambda _: self._get_x_on_road(flag_y, flag_segment, self.consts.SECOND_TRACK_CORNERS_X), + operand=None, + ) + screen_y = 105 + (flag_y - state.player_car.position.y) + is_visible = jnp.logical_and( + jnp.logical_and(screen_y > 25, screen_y < 195), + ~flag_collected + ) + color_id = self.flag_palette_ids[flag_color_idx] + colored_flag_mask = jnp.where( + self.flag_solid_mask, + color_id, + self.flag_base_mask, + ) + raster = jax.lax.cond( + is_visible, + lambda r: self.jr.render_at( + self.jr.render_at(r, flag_x.astype(jnp.int32), screen_y.astype(jnp.int32), colored_flag_mask), + (flag_x + 5).astype(jnp.int32), screen_y.astype(jnp.int32), flag_pole_mask + ), + lambda r: r, + operand=raster, + ) + return raster, None - raster = jax.lax.cond( - flag_collected, - lambda r: self.jr.render_at(r, blackout_x, blackout_y, blackout_mask), - lambda r: r, - operand=raster, - ) - return raster, None - - raster, _ = jax.lax.scan(render_blackout, raster, jnp.arange(self.consts.NUM_FLAGS)) - - # Render collectibles (unified for all types: cherry, balloon, lollypop, ice cream) - def render_collectible(carry, collectible_idx): - raster = carry - collectible_y = state.collectibles.y[collectible_idx] - collectible_x = state.collectibles.x[collectible_idx] - collectible_active = state.collectibles.active[collectible_idx] - collectible_color_idx = state.collectibles.color_idx[collectible_idx] - collectible_type_id = state.collectibles.type_id[collectible_idx] + raster_flags, _ = jax.lax.scan(render_flag, raster_score, jnp.arange(self.consts.NUM_FLAGS)) - # Calculate screen Y position relative to player - screen_y = 105 + (collectible_y - state.player_car.position.y) + blackout_mask = self.SHAPE_MASKS["blackout_square"] - # Check if collectible is visible on screen and active - is_visible = jnp.logical_and( - jnp.logical_and(screen_y > 25, screen_y < 195), - collectible_active - ) + def render_blackout(carry, flag_idx): + raster = carry + flag_collected = state.flags_collected_mask[flag_idx] + blackout_x = self.consts.FLAG_TOP_X_POSITIONS[flag_idx] + blackout_y = self.consts.FLAG_TOP_Y + raster = jax.lax.cond( + flag_collected, + lambda r: self.jr.render_at(r, blackout_x, blackout_y, blackout_mask), + lambda r: r, + operand=raster, + ) + return raster, None - # Select sprite based on type_id - # type_id: 0=cherry, 1=balloon, 2=lollypop, 3=ice_cream - def get_sprite_and_mask(type_id): - cherry_result = (self.cherry_base_mask, self.cherry_solid_mask, self.cherry_palette_ids) - balloon_result = (self.balloon_base_mask, self.balloon_solid_mask, self.balloon_palette_ids) - lollypop_result = (self.lollypop_base_mask, self.lollypop_solid_mask, self.lollypop_palette_ids) - ice_cream_result = (self.ice_cream_base_mask, self.ice_cream_solid_mask, self.ice_cream_palette_ids) - - # Use conditional branching to select sprite - result = jax.lax.cond( - type_id == self.consts.COLLECTIBLE_TYPE_CHERRY, - lambda _: cherry_result, - lambda _: jax.lax.cond( - type_id == self.consts.COLLECTIBLE_TYPE_BALLOON, - lambda _: balloon_result, + raster_blackout, _ = jax.lax.scan(render_blackout, raster_flags, jnp.arange(self.consts.NUM_FLAGS)) + + def render_collectible(carry, collectible_idx): + raster = carry + collectible_y = state.collectibles.y[collectible_idx] + collectible_x = state.collectibles.x[collectible_idx] + collectible_active = state.collectibles.active[collectible_idx] + collectible_color_idx = state.collectibles.color_idx[collectible_idx] + collectible_type_id = state.collectibles.type_id[collectible_idx] + screen_y = 105 + (collectible_y - state.player_car.position.y) + is_visible = jnp.logical_and( + jnp.logical_and(screen_y > 25, screen_y < 195), + collectible_active + ) + + def get_sprite_and_mask(type_id): + cherry_result = (self.cherry_base_mask, self.cherry_solid_mask, self.cherry_palette_ids) + balloon_result = (self.balloon_base_mask, self.balloon_solid_mask, self.balloon_palette_ids) + lollypop_result = (self.lollypop_base_mask, self.lollypop_solid_mask, self.lollypop_palette_ids) + ice_cream_result = (self.ice_cream_base_mask, self.ice_cream_solid_mask, self.ice_cream_palette_ids) + return jax.lax.cond( + type_id == self.consts.COLLECTIBLE_TYPE_CHERRY, + lambda _: cherry_result, lambda _: jax.lax.cond( - type_id == self.consts.COLLECTIBLE_TYPE_LOLLYPOP, - lambda _: lollypop_result, - lambda _: ice_cream_result, + type_id == self.consts.COLLECTIBLE_TYPE_BALLOON, + lambda _: balloon_result, + lambda _: jax.lax.cond( + type_id == self.consts.COLLECTIBLE_TYPE_LOLLYPOP, + lambda _: lollypop_result, + lambda _: ice_cream_result, + operand=None, + ), operand=None, ), operand=None, - ), - operand=None, + ) + + base_mask, solid_mask, palette_ids = get_sprite_and_mask(collectible_type_id) + color_id = palette_ids[collectible_color_idx] + colored_mask = jnp.where( + (base_mask != self.jr.TRANSPARENT_ID) & (base_mask != 0), + color_id, + base_mask, ) - return result - - base_mask, solid_mask, palette_ids = get_sprite_and_mask(collectible_type_id) - - # Only colorize inner pixels, keep black edges (palette ID 0 is black) - color_id = palette_ids[collectible_color_idx] - colored_mask = jnp.where( - (base_mask != self.jr.TRANSPARENT_ID) & (base_mask != 0), - color_id, - base_mask, - ) - - # Render collectible if visible - raster = jax.lax.cond( - is_visible, - lambda r: self.jr.render_at(r, collectible_x.astype(jnp.int32), screen_y.astype(jnp.int32), colored_mask), - lambda r: r, - operand=raster, - ) - return raster, None - - raster, _ = jax.lax.scan(render_collectible, raster, jnp.arange(self.consts.MAX_COLLECTIBLES)) + raster = jax.lax.cond( + is_visible, + lambda r: self.jr.render_at(r, collectible_x.astype(jnp.int32), screen_y.astype(jnp.int32), colored_mask), + lambda r: r, + operand=raster, + ) + return raster, None - all_lives_bottom_mask = self.SHAPE_MASKS["all_lives_bottom"] - raster = self.jr.render_at(raster, 10, 195, all_lives_bottom_mask) + raster_collectibles, _ = jax.lax.scan(render_collectible, raster_blackout, jnp.arange(self.consts.MAX_COLLECTIBLES)) - wall_bottom_mask = self.SHAPE_MASKS["tempPointer"] - raster = self.jr.render_at(raster, 140, 25, wall_bottom_mask) + all_lives_bottom_mask = self.SHAPE_MASKS["all_lives_bottom"] + raster_lives = self.jr.render_at(raster_collectibles, 10, 195, all_lives_bottom_mask) - return self.jr.render_from_palette(raster, self.PALETTE) - - def _get_flag_x_on_road(self, y: chex.Array, road_segment: chex.Array, track_corners_x: chex.Array) -> chex.Array: - """Calculate the X position on a road given a Y coordinate and road segment.""" - y1 = self.consts.TRACK_CORNERS_Y[road_segment] - y2 = self.consts.TRACK_CORNERS_Y[road_segment + 1] - x1 = track_corners_x[road_segment] - x2 = track_corners_x[road_segment + 1] - - # Linear interpolation: x = x1 + (y - y1) * (x2 - x1) / (y2 - y1) - t = jax.lax.cond( - y2 != y1, - lambda _: (y - y1) / (y2 - y1), - lambda _: 0.0, + wall_bottom_mask = self.SHAPE_MASKS["tempPointer"] + raster_pointer = self.jr.render_at(raster_lives, 140, 25, wall_bottom_mask) + + return self.jr.render_from_palette(raster_pointer, self.PALETTE) + + base_scene = jax.lax.cond( + hide_world, + lambda _: render_roads_only(), + lambda _: render_full_scene(), operand=None, ) - return x1 + t * (x2 - x1) \ No newline at end of file + + return jax.lax.cond( + hide_world, + lambda _: self.jr.render_from_palette(base_scene, self.PALETTE), + lambda _: render_rest(base_scene), + operand=None, + ) \ No newline at end of file From 2cd73b06605f2319fa9204a6b414e545982b4454 Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sat, 20 Dec 2025 21:33:07 +0100 Subject: [PATCH 57/76] add missing live counter to game --- src/jaxatari/games/jax_upndown.py | 391 +++++++++++++++--------------- 1 file changed, 189 insertions(+), 202 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 2d8de2915..52e5cd88a 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -19,7 +19,6 @@ class UpNDownConstants(NamedTuple): ACTION_REPEAT_PROBS: float = 0.25 MAX_SPEED: int = 6 INITIAL_LIVES: int = 5 - RESPAWN_HIDE_FRAMES: int = 8 JUMP_ARC_HEIGHT: float = 18.0 # Enemy spawning and movement MAX_ENEMY_CARS: int = 8 @@ -71,6 +70,10 @@ class UpNDownConstants(NamedTuple): FLAG_TOP_Y: int = 20 FLAG_BLACKOUT_SIZE: Tuple[int, int] = (14, 14) # Size of blackout square FLAG_COLLECTION_SCORE: int = 75 # Points awarded for collecting a flag + # Life display constants - positions of life cars at the bottom + LIFE_BOTTOM_X_POSITIONS: chex.Array = jnp.array([13, 18, 25, 33, 33]) # X positions for 5 life cars + LIFE_BOTTOM_Y: int = 195 + LIFE_BLACKOUT_SIZE: Tuple[int, int] = (14, 14) # Size of blackout square for lives PICKUP_SCORE: int = 100 # Points awarded for jumping on a pickup truck FLAG_CARRIER_SCORE: int = 125 # Points awarded for jumping on a flag carrier CAMARO_SCORE: int = 150 # Points awarded for jumping on a camaro @@ -147,7 +150,6 @@ class EnemyCars(NamedTuple): class UpNDownState(NamedTuple): score: chex.Array lives: chex.Array - respawn_cooldown: chex.Array difficulty: chex.Array jump_cooldown: chex.Array post_jump_cooldown: chex.Array @@ -1028,13 +1030,6 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: is_on_road = ~is_jumping is_landing = jnp.logical_and(state.jump_cooldown == 1, jump_cooldown == 0) - respawn_cooldown = jax.lax.cond( - state.respawn_cooldown > 0, - lambda _: state.respawn_cooldown - 1, - lambda _: jnp.array(0, dtype=jnp.int32), - operand=None, - ) - updated_player_car = self._advance_player_car( position_x=state.player_car.position.x, position_y=state.player_car.position.y, @@ -1058,7 +1053,6 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: round_started_now = jnp.logical_or(state.round_started, speed_action_taken) next_state = state._replace( - respawn_cooldown=respawn_cooldown, jump_cooldown=jump_cooldown, post_jump_cooldown=post_jump_cooldown, is_jumping=is_jumping, @@ -1353,7 +1347,6 @@ def _respawn_after_collision(self, state: UpNDownState, new_lives: chex.Array) - return UpNDownState( score=state.score, lives=new_lives, - respawn_cooldown=jnp.array(self.consts.RESPAWN_HIDE_FRAMES, dtype=jnp.int32), difficulty=state.difficulty, jump_cooldown=jnp.array(0, dtype=jnp.int32), post_jump_cooldown=jnp.array(0, dtype=jnp.int32), @@ -1511,7 +1504,6 @@ def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: state = UpNDownState( score=0, lives=jnp.array(self.consts.INITIAL_LIVES, dtype=jnp.int32), - respawn_cooldown=jnp.array(0, dtype=jnp.int32), difficulty=self.consts.DIFFICULTIES[0], jump_cooldown=0, post_jump_cooldown=0, @@ -1868,214 +1860,209 @@ def combine(i, acc): raster = jax.lax.fori_loop(0, total_segments, combine, raster) - hide_world = state.respawn_cooldown > 0 - - # During respawn hide, only show the road/background to emulate an initial road state. - def render_roads_only(): - return raster - - def render_full_scene(): - def select_enemy_mask(enemy_type: chex.Array, going_left: chex.Array): - left_mask = self.enemy_left_masks[enemy_type] - right_mask = self.enemy_right_masks[enemy_type] - return jax.lax.cond(going_left, lambda _: left_mask, lambda _: right_mask, operand=None) - - def render_enemy(carry, enemy_idx): - raster = carry - enemy_active = state.enemy_cars.active[enemy_idx] - enemy_x = state.enemy_cars.position.x[enemy_idx] - enemy_y = state.enemy_cars.position.y[enemy_idx] - enemy_type = state.enemy_cars.type[enemy_idx] - direction_x = state.enemy_cars.direction_x[enemy_idx] - screen_y = 105 + (enemy_y - state.player_car.position.y) - is_visible = jnp.logical_and(enemy_active, jnp.logical_and(screen_y > 25, screen_y < 195)) - enemy_mask = select_enemy_mask(enemy_type, direction_x < 0) - - raster = jax.lax.cond( - is_visible, - lambda r: self.jr.render_at(r, enemy_x.astype(jnp.int32), screen_y.astype(jnp.int32), enemy_mask), - lambda r: r, - operand=raster, - ) - return raster, None - - raster_enemies, _ = jax.lax.scan(render_enemy, raster, jnp.arange(self.consts.MAX_ENEMY_CARS)) - - jump_offset = jax.lax.cond( - state.is_jumping, - lambda _: self._jump_arc_offset(state.jump_cooldown), - lambda _: jnp.array(0.0, dtype=jnp.float32), - operand=None, + def select_enemy_mask(enemy_type: chex.Array, going_left: chex.Array): + left_mask = self.enemy_left_masks[enemy_type] + right_mask = self.enemy_right_masks[enemy_type] + return jax.lax.cond(going_left, lambda _: left_mask, lambda _: right_mask, operand=None) + + def render_enemy(carry, enemy_idx): + raster = carry + enemy_active = state.enemy_cars.active[enemy_idx] + enemy_x = state.enemy_cars.position.x[enemy_idx] + enemy_y = state.enemy_cars.position.y[enemy_idx] + enemy_type = state.enemy_cars.type[enemy_idx] + direction_x = state.enemy_cars.direction_x[enemy_idx] + screen_y = 105 + (enemy_y - state.player_car.position.y) + is_visible = jnp.logical_and(enemy_active, jnp.logical_and(screen_y > 25, screen_y < 195)) + enemy_mask = select_enemy_mask(enemy_type, direction_x < 0) + + raster = jax.lax.cond( + is_visible, + lambda r: self.jr.render_at(r, enemy_x.astype(jnp.int32), screen_y.astype(jnp.int32), enemy_mask), + lambda r: r, + operand=raster, ) + return raster, None - player_screen_y = jnp.int32(105 - jump_offset) - player_mask = self.SHAPE_MASKS["player"] - raster_player = self.jr.render_at(raster_enemies, state.player_car.position.x, player_screen_y, player_mask) + raster_enemies, _ = jax.lax.scan(render_enemy, raster, jnp.arange(self.consts.MAX_ENEMY_CARS)) - wall_top_mask = self.SHAPE_MASKS["wall_top"] - raster_wall_top = self.jr.render_at(raster_player, 0, 0, wall_top_mask) + jump_offset = jax.lax.cond( + state.is_jumping, + lambda _: self._jump_arc_offset(state.jump_cooldown), + lambda _: jnp.array(0.0, dtype=jnp.float32), + operand=None, + ) - wall_bottom_mask = self.SHAPE_MASKS["wall_bottom"] - raster_wall_bottom = self.jr.render_at(raster_wall_top, 0, 210 - wall_bottom_mask.shape[0], wall_bottom_mask) + player_screen_y = jnp.int32(105 - jump_offset) + player_mask = self.SHAPE_MASKS["player"] + raster_player = self.jr.render_at(raster_enemies, state.player_car.position.x, player_screen_y, player_mask) - all_flags_top_mask = self.SHAPE_MASKS["all_flags_top"] - raster_flags_top = self.jr.render_at(raster_wall_bottom, 10, 20, all_flags_top_mask) + wall_top_mask = self.SHAPE_MASKS["wall_top"] + raster_wall_top = self.jr.render_at(raster_player, 0, 0, wall_top_mask) - return raster_flags_top + wall_bottom_mask = self.SHAPE_MASKS["wall_bottom"] + raster_wall_bottom = self.jr.render_at(raster_wall_top, 0, 210 - wall_bottom_mask.shape[0], wall_bottom_mask) + all_flags_top_mask = self.SHAPE_MASKS["all_flags_top"] + raster_flags_top = self.jr.render_at(raster_wall_bottom, 10, 20, all_flags_top_mask) - def render_rest(raster_input): - # Render score centered at the top using dedicated score digit sprites - score_digits = self.jr.int_to_digits(state.score, max_digits=self.score_max_digits) - non_zero_mask = score_digits != 0 - has_non_zero = jnp.any(non_zero_mask) - first_non_zero = jnp.argmax(non_zero_mask) - start_index = jax.lax.select(has_non_zero, first_non_zero, self.score_max_digits - 1) - num_to_render = jax.lax.select(has_non_zero, self.score_max_digits - start_index, 1) + # Render score centered at the top using dedicated score digit sprites + score_digits = self.jr.int_to_digits(state.score, max_digits=self.score_max_digits) + non_zero_mask = score_digits != 0 + has_non_zero = jnp.any(non_zero_mask) + first_non_zero = jnp.argmax(non_zero_mask) + start_index = jax.lax.select(has_non_zero, first_non_zero, self.score_max_digits - 1) + num_to_render = jax.lax.select(has_non_zero, self.score_max_digits - start_index, 1) - total_width = num_to_render * self.score_digit_spacing - score_x = self.score_center_x - (total_width // 2) + total_width = num_to_render * self.score_digit_spacing + score_x = self.score_center_x - (total_width // 2) - raster_score = self.jr.render_label_selective( - raster_input, - jnp.int32(score_x), - self.score_render_y, - score_digits, - self.score_digit_masks, - start_index, - num_to_render, - spacing=self.score_digit_spacing, - max_digits_to_render=self.score_max_digits, - ) + raster_score = self.jr.render_label_selective( + raster_flags_top, + jnp.int32(score_x), + self.score_render_y, + score_digits, + self.score_digit_masks, + start_index, + num_to_render, + spacing=self.score_digit_spacing, + max_digits_to_render=self.score_max_digits, + ) - # Render flags on the road - flag_pole_mask = self.SHAPE_MASKS["flag_pole"] - - def render_flag(carry, flag_idx): - raster = carry - flag_y = state.flags.y[flag_idx] - flag_road = state.flags.road[flag_idx] - flag_segment = state.flags.road_segment[flag_idx] - flag_collected = state.flags.collected[flag_idx] - flag_color_idx = state.flags.color_idx[flag_idx] - - flag_x = jax.lax.cond( - flag_road == 0, - lambda _: self._get_x_on_road(flag_y, flag_segment, self.consts.FIRST_TRACK_CORNERS_X), - lambda _: self._get_x_on_road(flag_y, flag_segment, self.consts.SECOND_TRACK_CORNERS_X), - operand=None, - ) - screen_y = 105 + (flag_y - state.player_car.position.y) - is_visible = jnp.logical_and( - jnp.logical_and(screen_y > 25, screen_y < 195), - ~flag_collected - ) - color_id = self.flag_palette_ids[flag_color_idx] - colored_flag_mask = jnp.where( - self.flag_solid_mask, - color_id, - self.flag_base_mask, - ) - raster = jax.lax.cond( - is_visible, - lambda r: self.jr.render_at( - self.jr.render_at(r, flag_x.astype(jnp.int32), screen_y.astype(jnp.int32), colored_flag_mask), - (flag_x + 5).astype(jnp.int32), screen_y.astype(jnp.int32), flag_pole_mask - ), - lambda r: r, - operand=raster, - ) - return raster, None - - raster_flags, _ = jax.lax.scan(render_flag, raster_score, jnp.arange(self.consts.NUM_FLAGS)) - - blackout_mask = self.SHAPE_MASKS["blackout_square"] - - def render_blackout(carry, flag_idx): - raster = carry - flag_collected = state.flags_collected_mask[flag_idx] - blackout_x = self.consts.FLAG_TOP_X_POSITIONS[flag_idx] - blackout_y = self.consts.FLAG_TOP_Y - raster = jax.lax.cond( - flag_collected, - lambda r: self.jr.render_at(r, blackout_x, blackout_y, blackout_mask), - lambda r: r, - operand=raster, - ) - return raster, None + # Render flags on the road + flag_pole_mask = self.SHAPE_MASKS["flag_pole"] + + def render_flag(carry, flag_idx): + raster = carry + flag_y = state.flags.y[flag_idx] + flag_road = state.flags.road[flag_idx] + flag_segment = state.flags.road_segment[flag_idx] + flag_collected = state.flags.collected[flag_idx] + flag_color_idx = state.flags.color_idx[flag_idx] - raster_blackout, _ = jax.lax.scan(render_blackout, raster_flags, jnp.arange(self.consts.NUM_FLAGS)) - - def render_collectible(carry, collectible_idx): - raster = carry - collectible_y = state.collectibles.y[collectible_idx] - collectible_x = state.collectibles.x[collectible_idx] - collectible_active = state.collectibles.active[collectible_idx] - collectible_color_idx = state.collectibles.color_idx[collectible_idx] - collectible_type_id = state.collectibles.type_id[collectible_idx] - screen_y = 105 + (collectible_y - state.player_car.position.y) - is_visible = jnp.logical_and( - jnp.logical_and(screen_y > 25, screen_y < 195), - collectible_active - ) + flag_x = jax.lax.cond( + flag_road == 0, + lambda _: self._get_x_on_road(flag_y, flag_segment, self.consts.FIRST_TRACK_CORNERS_X), + lambda _: self._get_x_on_road(flag_y, flag_segment, self.consts.SECOND_TRACK_CORNERS_X), + operand=None, + ) + screen_y = 105 + (flag_y - state.player_car.position.y) + is_visible = jnp.logical_and( + jnp.logical_and(screen_y > 25, screen_y < 195), + ~flag_collected + ) + color_id = self.flag_palette_ids[flag_color_idx] + colored_flag_mask = jnp.where( + self.flag_solid_mask, + color_id, + self.flag_base_mask, + ) + raster = jax.lax.cond( + is_visible, + lambda r: self.jr.render_at( + self.jr.render_at(r, flag_x.astype(jnp.int32), screen_y.astype(jnp.int32), colored_flag_mask), + (flag_x + 5).astype(jnp.int32), screen_y.astype(jnp.int32), flag_pole_mask + ), + lambda r: r, + operand=raster, + ) + return raster, None + + raster_flags, _ = jax.lax.scan(render_flag, raster_score, jnp.arange(self.consts.NUM_FLAGS)) + + blackout_mask = self.SHAPE_MASKS["blackout_square"] + + def render_blackout(carry, flag_idx): + raster = carry + flag_collected = state.flags_collected_mask[flag_idx] + blackout_x = self.consts.FLAG_TOP_X_POSITIONS[flag_idx] + blackout_y = self.consts.FLAG_TOP_Y + raster = jax.lax.cond( + flag_collected, + lambda r: self.jr.render_at(r, blackout_x, blackout_y, blackout_mask), + lambda r: r, + operand=raster, + ) + return raster, None + + raster_blackout, _ = jax.lax.scan(render_blackout, raster_flags, jnp.arange(self.consts.NUM_FLAGS)) + + def render_collectible(carry, collectible_idx): + raster = carry + collectible_y = state.collectibles.y[collectible_idx] + collectible_x = state.collectibles.x[collectible_idx] + collectible_active = state.collectibles.active[collectible_idx] + collectible_color_idx = state.collectibles.color_idx[collectible_idx] + collectible_type_id = state.collectibles.type_id[collectible_idx] + screen_y = 105 + (collectible_y - state.player_car.position.y) + is_visible = jnp.logical_and( + jnp.logical_and(screen_y > 25, screen_y < 195), + collectible_active + ) - def get_sprite_and_mask(type_id): - cherry_result = (self.cherry_base_mask, self.cherry_solid_mask, self.cherry_palette_ids) - balloon_result = (self.balloon_base_mask, self.balloon_solid_mask, self.balloon_palette_ids) - lollypop_result = (self.lollypop_base_mask, self.lollypop_solid_mask, self.lollypop_palette_ids) - ice_cream_result = (self.ice_cream_base_mask, self.ice_cream_solid_mask, self.ice_cream_palette_ids) - return jax.lax.cond( - type_id == self.consts.COLLECTIBLE_TYPE_CHERRY, - lambda _: cherry_result, + def get_sprite_and_mask(type_id): + cherry_result = (self.cherry_base_mask, self.cherry_solid_mask, self.cherry_palette_ids) + balloon_result = (self.balloon_base_mask, self.balloon_solid_mask, self.balloon_palette_ids) + lollypop_result = (self.lollypop_base_mask, self.lollypop_solid_mask, self.lollypop_palette_ids) + ice_cream_result = (self.ice_cream_base_mask, self.ice_cream_solid_mask, self.ice_cream_palette_ids) + return jax.lax.cond( + type_id == self.consts.COLLECTIBLE_TYPE_CHERRY, + lambda _: cherry_result, + lambda _: jax.lax.cond( + type_id == self.consts.COLLECTIBLE_TYPE_BALLOON, + lambda _: balloon_result, lambda _: jax.lax.cond( - type_id == self.consts.COLLECTIBLE_TYPE_BALLOON, - lambda _: balloon_result, - lambda _: jax.lax.cond( - type_id == self.consts.COLLECTIBLE_TYPE_LOLLYPOP, - lambda _: lollypop_result, - lambda _: ice_cream_result, - operand=None, - ), + type_id == self.consts.COLLECTIBLE_TYPE_LOLLYPOP, + lambda _: lollypop_result, + lambda _: ice_cream_result, operand=None, ), operand=None, - ) - - base_mask, solid_mask, palette_ids = get_sprite_and_mask(collectible_type_id) - color_id = palette_ids[collectible_color_idx] - colored_mask = jnp.where( - (base_mask != self.jr.TRANSPARENT_ID) & (base_mask != 0), - color_id, - base_mask, - ) - raster = jax.lax.cond( - is_visible, - lambda r: self.jr.render_at(r, collectible_x.astype(jnp.int32), screen_y.astype(jnp.int32), colored_mask), - lambda r: r, - operand=raster, + ), + operand=None, ) - return raster, None - - raster_collectibles, _ = jax.lax.scan(render_collectible, raster_blackout, jnp.arange(self.consts.MAX_COLLECTIBLES)) - all_lives_bottom_mask = self.SHAPE_MASKS["all_lives_bottom"] - raster_lives = self.jr.render_at(raster_collectibles, 10, 195, all_lives_bottom_mask) - - wall_bottom_mask = self.SHAPE_MASKS["tempPointer"] - raster_pointer = self.jr.render_at(raster_lives, 140, 25, wall_bottom_mask) - - return self.jr.render_from_palette(raster_pointer, self.PALETTE) + base_mask, solid_mask, palette_ids = get_sprite_and_mask(collectible_type_id) + color_id = palette_ids[collectible_color_idx] + colored_mask = jnp.where( + (base_mask != self.jr.TRANSPARENT_ID) & (base_mask != 0), + color_id, + base_mask, + ) + raster = jax.lax.cond( + is_visible, + lambda r: self.jr.render_at(r, collectible_x.astype(jnp.int32), screen_y.astype(jnp.int32), colored_mask), + lambda r: r, + operand=raster, + ) + return raster, None + + raster_collectibles, _ = jax.lax.scan(render_collectible, raster_blackout, jnp.arange(self.consts.MAX_COLLECTIBLES)) + + all_lives_bottom_mask = self.SHAPE_MASKS["all_lives_bottom"] + raster_lives = self.jr.render_at(raster_collectibles, 10, 195, all_lives_bottom_mask) + + # Black out lost lives (similar to flag blackout) + blackout_mask = self.SHAPE_MASKS["blackout_square"] + lives_lost = self.consts.INITIAL_LIVES - state.lives + + def render_life_blackout(carry, life_idx): + raster = carry + # Black out this life if it has been lost (life_idx < lives_lost) + should_blackout = life_idx < lives_lost + blackout_x = self.consts.LIFE_BOTTOM_X_POSITIONS[life_idx] + blackout_y = self.consts.LIFE_BOTTOM_Y + raster = jax.lax.cond( + should_blackout, + lambda r: self.jr.render_at(r, blackout_x, blackout_y, blackout_mask), + lambda r: r, + operand=raster, + ) + return raster, None + + raster_lives_blackout, _ = jax.lax.scan(render_life_blackout, raster_lives, jnp.arange(self.consts.INITIAL_LIVES)) - base_scene = jax.lax.cond( - hide_world, - lambda _: render_roads_only(), - lambda _: render_full_scene(), - operand=None, - ) + wall_bottom_mask = self.SHAPE_MASKS["tempPointer"] + raster_pointer = self.jr.render_at(raster_lives_blackout, 140, 25, wall_bottom_mask) - return jax.lax.cond( - hide_world, - lambda _: self.jr.render_from_palette(base_scene, self.PALETTE), - lambda _: render_rest(base_scene), - operand=None, - ) \ No newline at end of file + return self.jr.render_from_palette(raster_pointer, self.PALETTE) \ No newline at end of file From 5afcd7f95339da5c80c56276bcc93fc7a7a94cfc Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sat, 20 Dec 2025 22:30:52 +0100 Subject: [PATCH 58/76] cleanup code --- src/jaxatari/games/jax_upndown.py | 288 ++++++++++++++---------------- 1 file changed, 134 insertions(+), 154 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 52e5cd88a..4933e8cbf 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -1,8 +1,8 @@ -from jax._src.pjit import JitWrapped import os import math from functools import partial from typing import NamedTuple, Tuple + import jax import jax.lax import jax.numpy as jnp @@ -14,9 +14,8 @@ from jaxatari.environment import JaxEnvironment, JAXAtariAction as Action class UpNDownConstants(NamedTuple): - FRAME_SKIP: int = 4 + FRAME_SKIP: int = 4 # Used by AtariWrapper DIFFICULTIES: chex.Array = jnp.array([0, 1, 2, 3, 4, 5]) - ACTION_REPEAT_PROBS: float = 0.25 MAX_SPEED: int = 6 INITIAL_LIVES: int = 5 JUMP_ARC_HEIGHT: float = 18.0 @@ -42,8 +41,12 @@ class UpNDownConstants(NamedTuple): LANDING_TOLERANCE: int = 15 # Pixels tolerance for landing on a road (increased by 5 for off-road landings) LATE_JUMP_COLLISION_FRAMES: int = 2 LANDING_COLLISION_DISTANCE: float = 8.0 # Larger collision distance when landing (for crossings) + GROUND_COLLISION_DISTANCE: float = 3.0 # Tight collision distance for ground collisions LATE_JUMP_ENEMY_SCORE: int = 400 STEEP_ROAD_SPEED_REDUCTION_INTERVAL: int = 8 # Frames between each speed reduction on steep roads + PASSIVE_SCORE_INTERVAL: int = 60 # Steps between passive score awards + PASSIVE_SCORE_AMOUNT: int = 10 # Points awarded for passive scoring + COLLISION_THRESHOLD: float = 5.0 # Distance threshold for flag/collectible collision TRACK_LENGTH: int = 1036 FIRST_TRACK_CORNERS_X: chex.Array = jnp.array([30, 75, 128, 75, 21, 75, 131, 111, 150, 95, 150, 115, 150, 108, 150, 115, 115, 75, 18, 38, 67, 38, 38, 20, 64, 30]) TRACK_CORNERS_Y: chex.Array = jnp.array([0, -40, -98, -155, -203, -268, -327, -347, -382, -467, -525, -565, -597, -625, -670, -705, -738, -788, -838, -862, -898, -925, -950, -972, -1000, -1035]) @@ -52,8 +55,6 @@ class UpNDownConstants(NamedTuple): INITIAL_ROAD_POS_Y: int = 25 # Flag constants - 8 flags with different colors matching the top row NUM_FLAGS: int = 8 - FLAG_SIZE: Tuple[int, int] = (11, 6) # height, width of the flag sprite - FLAG_POLE_SIZE: Tuple[int, int] = (7, 2) # height, width of the pole sprite # Flag colors as RGBA values (matching the top row from left to right) FLAG_COLORS: chex.Array = jnp.array([ [184, 50, 50, 255], # Red @@ -73,14 +74,8 @@ class UpNDownConstants(NamedTuple): # Life display constants - positions of life cars at the bottom LIFE_BOTTOM_X_POSITIONS: chex.Array = jnp.array([13, 18, 25, 33, 33]) # X positions for 5 life cars LIFE_BOTTOM_Y: int = 195 - LIFE_BLACKOUT_SIZE: Tuple[int, int] = (14, 14) # Size of blackout square for lives - PICKUP_SCORE: int = 100 # Points awarded for jumping on a pickup truck - FLAG_CARRIER_SCORE: int = 125 # Points awarded for jumping on a flag carrier - CAMARO_SCORE: int = 150 # Points awarded for jumping on a camaro - TRUCK_SCORE: int = 175 # Points awarded for jumping on a truck # Collectible constants - unified dynamic spawning MAX_COLLECTIBLES: int = 2 # Maximum collectibles that can exist at once (pool of mixed types) - COLLECTIBLE_SIZE: Tuple[int, int] = (8, 8) # height, width of collectible sprite COLLECTIBLE_SPAWN_INTERVAL: int = 200 # Steps between spawn attempts COLLECTIBLE_DESPAWN_DISTANCE: int = 500 # Distance beyond which collectibles despawn # Collectible types (indices for type field) @@ -88,8 +83,8 @@ class UpNDownConstants(NamedTuple): COLLECTIBLE_TYPE_BALLOON: int = 1 COLLECTIBLE_TYPE_LOLLYPOP: int = 2 COLLECTIBLE_TYPE_ICE_CREAM: int = 3 - # Collectible type spawn probabilities (must sum to 100) - COLLECTIBLE_SPAWN_PROBABILITIES: chex.Array = jnp.array([35, 30, 25, 10], dtype=jnp.int32) # Cherry: 40%, Balloon: 20%, Lollypop: 20%, IceCream: 20% + # Collectible type spawn probabilities (cumulative thresholds for random sampling) + COLLECTIBLE_SPAWN_PROBABILITIES: chex.Array = jnp.array([35, 65, 90, 100], dtype=jnp.int32) # Cherry: 35%, Balloon: 30%, Lollypop: 25%, IceCream: 10% # Collectible type scores COLLECTIBLE_SCORES: chex.Array = jnp.array([50, 65, 70, 75], dtype=jnp.int32) # [cherry, balloon, lollypop, ice_cream] # Shared collectible colors @@ -201,6 +196,29 @@ def __init__(self, consts: UpNDownConstants = None, reward_funcs: list[callable] Action.DOWNFIRE, ] self.obs_size = 3*4+1+1 + # Speed dividers for movement timing (indexed by speed level) + self._speed_dividers = jnp.array([0, 1, 2, 4, 8, 16, 16]) + + @partial(jax.jit, static_argnums=(0,)) + def _compute_movement_timing(self, speed: chex.Array, step_counter: chex.Array) -> Tuple[chex.Array, chex.Array, chex.Array, chex.Array]: + """Calculate movement timing parameters based on speed. + + Returns: + Tuple of (move_y, move_x, step_size, speed_sign) + """ + abs_speed = jnp.abs(speed) + speed_index = jnp.minimum(abs_speed, jnp.int32(self._speed_dividers.shape[0] - 1)) + speed_divider = self._speed_dividers[speed_index] + effective_divider = jnp.maximum(1, speed_divider) + period = jnp.maximum(1, 16 // effective_divider) + half_period = jnp.maximum(1, period // 2) + speed_sign = jnp.sign(speed).astype(jnp.float32) + + move_y = jnp.logical_and((step_counter % period) == (half_period % period), speed != 0) + move_x = jnp.logical_and((step_counter % period) == 0, speed != 0) + step_size = jnp.where(speed_index >= self.consts.MAX_SPEED, 1.5, 1.0) + + return move_y, move_x, step_size, speed_sign @partial(jax.jit, static_argnums=(0,)) def _get_slope_and_intercept_from_indices(self, current_road: chex.Array, road_index_A: chex.Array, road_index_B: chex.Array) -> Tuple[chex.Array, chex.Array]: @@ -219,14 +237,6 @@ def _get_slope_and_intercept_from_indices(self, current_road: chex.Array, road_i b = tracky[road_index] - slope * trackx[road_index] return slope, b - @partial(jax.jit, static_argnums=(0,)) - def _getSlopeAndB(self, state: UpNDownState) -> chex.Array: - return self._get_slope_and_intercept_from_indices( - state.player_car.current_road, - state.player_car.road_index_A, - state.player_car.road_index_B, - ) - @partial(jax.jit, static_argnums=(0,)) def _is_on_line_for_position(self, position: EntityPosition, slope: chex.Array, b: chex.Array, player_speed: chex.Array, turn: chex.Array) -> chex.Array: x_step = abs(jnp.subtract(position.y, slope * (position.x) + b)) @@ -279,30 +289,46 @@ def _is_steep_road_segment(self, current_road: chex.Array, road_index_A: chex.Ar ) # A segment is steep if there's no X change (or very small change) return x_diff < 1.0 - - @partial(jax.jit, static_argnums=(0,)) - def _isOnLine(self, state: UpNDownState, player_speed: chex.Array, turn: chex.Array) -> chex.Array: - slope, b = self._getSlopeAndB(state) - return self._is_on_line_for_position(state.player_car.position, slope, b, player_speed, turn) - - @partial(jax.jit, static_argnums=(0,)) - def _landing_in_water(self, state: UpNDownState, new_position_x: chex.Array, new_position_y: chex.Array) -> chex.Array: - road_A_x = ((new_position_y - self.consts.TRACK_CORNERS_Y[state.player_car.road_index_A]) / (self.consts.TRACK_CORNERS_Y[state.player_car.road_index_A+1] - self.consts.TRACK_CORNERS_Y[state.player_car.road_index_A])) * (self.consts.FIRST_TRACK_CORNERS_X[state.player_car.road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_X[state.player_car.road_index_A]) + self.consts.FIRST_TRACK_CORNERS_X[state.player_car.road_index_A] - road_B_x = ((new_position_y - self.consts.TRACK_CORNERS_Y[state.player_car.road_index_B]) / (self.consts.TRACK_CORNERS_Y[state.player_car.road_index_B+1] - self.consts.TRACK_CORNERS_Y[state.player_car.road_index_B])) * (self.consts.SECOND_TRACK_CORNERS_X[state.player_car.road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[state.player_car.road_index_B]) + self.consts.SECOND_TRACK_CORNERS_X[state.player_car.road_index_B] - distance_to_road_A = jnp.abs(new_position_x - road_A_x) - distance_to_road_B = jnp.abs(new_position_x - road_B_x) - landing_in_Water = jnp.logical_and(distance_to_road_A > self.consts.LANDING_TOLERANCE, distance_to_road_B > self.consts.LANDING_TOLERANCE) - between_roads = jnp.logical_and(new_position_x > jnp.minimum(road_A_x, road_B_x), new_position_x < jnp.maximum(road_A_x, road_B_x)) - return landing_in_Water, between_roads, road_A_x, road_B_x @partial(jax.jit, static_argnums=(0,)) - def _landing_in_water_for_indices(self, road_index_A: chex.Array, road_index_B: chex.Array, new_position_x: chex.Array, new_position_y: chex.Array) -> chex.Array: - road_A_x = ((new_position_y - self.consts.TRACK_CORNERS_Y[road_index_A]) / (self.consts.TRACK_CORNERS_Y[road_index_A+1] - self.consts.TRACK_CORNERS_Y[road_index_A])) * (self.consts.FIRST_TRACK_CORNERS_X[road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_X[road_index_A]) + self.consts.FIRST_TRACK_CORNERS_X[road_index_A] - road_B_x = ((new_position_y - self.consts.TRACK_CORNERS_Y[road_index_B]) / (self.consts.TRACK_CORNERS_Y[road_index_B+1] - self.consts.TRACK_CORNERS_Y[road_index_B])) * (self.consts.SECOND_TRACK_CORNERS_X[road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[road_index_B]) + self.consts.SECOND_TRACK_CORNERS_X[road_index_B] + def _check_landing_position( + self, + road_index_A: chex.Array, + road_index_B: chex.Array, + new_position_x: chex.Array, + new_position_y: chex.Array, + ) -> Tuple[chex.Array, chex.Array, chex.Array, chex.Array]: + """Check if a position is valid for landing (on or between roads). + + Returns: + Tuple of (landing_in_water, between_roads, road_A_x, road_B_x) + """ + # Calculate X position on road A at the given Y + y_ratio_A = (new_position_y - self.consts.TRACK_CORNERS_Y[road_index_A]) / ( + self.consts.TRACK_CORNERS_Y[road_index_A + 1] - self.consts.TRACK_CORNERS_Y[road_index_A] + ) + road_A_x = y_ratio_A * ( + self.consts.FIRST_TRACK_CORNERS_X[road_index_A + 1] - self.consts.FIRST_TRACK_CORNERS_X[road_index_A] + ) + self.consts.FIRST_TRACK_CORNERS_X[road_index_A] + + # Calculate X position on road B at the given Y + y_ratio_B = (new_position_y - self.consts.TRACK_CORNERS_Y[road_index_B]) / ( + self.consts.TRACK_CORNERS_Y[road_index_B + 1] - self.consts.TRACK_CORNERS_Y[road_index_B] + ) + road_B_x = y_ratio_B * ( + self.consts.SECOND_TRACK_CORNERS_X[road_index_B + 1] - self.consts.SECOND_TRACK_CORNERS_X[road_index_B] + ) + self.consts.SECOND_TRACK_CORNERS_X[road_index_B] + distance_to_road_A = jnp.abs(new_position_x - road_A_x) distance_to_road_B = jnp.abs(new_position_x - road_B_x) - landing_in_water = jnp.logical_and(distance_to_road_A > self.consts.LANDING_TOLERANCE, distance_to_road_B > self.consts.LANDING_TOLERANCE) - between_roads = jnp.logical_and(new_position_x > jnp.minimum(road_A_x, road_B_x), new_position_x < jnp.maximum(road_A_x, road_B_x)) + landing_in_water = jnp.logical_and( + distance_to_road_A > self.consts.LANDING_TOLERANCE, + distance_to_road_B > self.consts.LANDING_TOLERANCE, + ) + between_roads = jnp.logical_and( + new_position_x > jnp.minimum(road_A_x, road_B_x), + new_position_x < jnp.maximum(road_A_x, road_B_x), + ) return landing_in_water, between_roads, road_A_x, road_B_x @partial(jax.jit, static_argnums=(0,)) @@ -315,7 +341,6 @@ def _advance_player_car( current_road: chex.Array, speed: chex.Array, is_jumping: chex.Array, - is_on_road: chex.Array, step_counter: chex.Array, width: chex.Array, height: chex.Array, @@ -333,15 +358,8 @@ def _advance_player_car( - If between roads: snap to nearest road - If too far from both roads (outside the road area): crash (water) """ - # Speed-based movement timing - dividers = jnp.array([0, 1, 2, 4, 8, 16, 16]) - abs_speed = jnp.abs(speed) - speed_index = jnp.minimum(abs_speed, jnp.int32(dividers.shape[0] - 1)) - speed_divider = dividers[speed_index] - effective_divider = jnp.maximum(1, speed_divider) - period = jnp.maximum(1, 16 // effective_divider) - half_period = jnp.maximum(1, period // 2) - speed_sign = jax.lax.cond(speed != 0, lambda _: jax.lax.abs(speed) / speed, lambda _: jnp.array(0.0), operand=None) + # Calculate movement timing using helper + move_y, move_x, step_size, speed_sign = self._compute_movement_timing(speed, step_counter) # Get slope and intercept for current road slope, b = self._get_slope_and_intercept_from_indices(current_road, road_index_A, road_index_B) @@ -353,14 +371,8 @@ def _advance_player_car( lambda _: self.consts.SECOND_TRACK_CORNERS_X[road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[road_index_B], operand=None, ) - car_direction_x = jax.lax.cond(direction_raw > 0, lambda _: 1, lambda _: -1, operand=None) - - # Movement timing flags - move_y = jnp.logical_and((step_counter % period) == (half_period % period), speed != 0) - move_x = jnp.logical_and((step_counter % period) == 0, speed != 0) - - # Step size (slightly larger at max speed) - step_size = jnp.where(speed_index >= self.consts.MAX_SPEED, 1.5, 1.0) + # Use sign, default to -1 for zero (vertical segments) + car_direction_x = jnp.where(direction_raw == 0, -1, jnp.sign(direction_raw)).astype(jnp.int32) position = EntityPosition(x=position_x, y=position_y, width=width, height=height) @@ -532,14 +544,8 @@ def _advance_car_core( car_type: chex.Array, ) -> Car: """Simplified car advancement for enemy cars (no jumping/landing logic).""" - dividers = jnp.array([0, 1, 2, 4, 8, 16, 16]) - abs_speed = jnp.abs(speed) - speed_index = jnp.minimum(abs_speed, jnp.int32(dividers.shape[0] - 1)) - speed_divider = dividers[speed_index] - effective_divider = jnp.maximum(1, speed_divider) - period = jnp.maximum(1, 16 // effective_divider) - half_period = jnp.maximum(1, period // 2) - speed_sign = jax.lax.cond(speed != 0, lambda _: jax.lax.abs(speed) / speed, lambda _: jnp.array(0.0), operand=None) + # Calculate movement timing using helper + move_y, move_x, step_size, speed_sign = self._compute_movement_timing(speed, step_counter) slope, b = self._get_slope_and_intercept_from_indices(current_road, road_index_A, road_index_B) @@ -549,12 +555,8 @@ def _advance_car_core( lambda _: self.consts.SECOND_TRACK_CORNERS_X[road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[road_index_B], operand=None, ) - car_direction_x = jax.lax.cond(direction_raw > 0, lambda _: 1, lambda _: -1, operand=None) - - move_y = jnp.logical_and((step_counter % period) == (half_period % period), speed != 0) - move_x = jnp.logical_and((step_counter % period) == 0, speed != 0) - - step_size = jnp.where(speed_index >= self.consts.MAX_SPEED, 1.5, 1.0) + # Use sign, default to -1 for zero (vertical segments) + car_direction_x = jnp.where(direction_raw == 0, -1, jnp.sign(direction_raw)).astype(jnp.int32) position = EntityPosition(x=position_x, y=position_y, width=width, height=height) @@ -653,7 +655,7 @@ def check_flag_collision(flag_idx): ) collision = jnp.logical_and( - jnp.logical_and(y_distance < 5, x_distance < 5), #change the distance threshold if needed + jnp.logical_and(y_distance < self.consts.COLLISION_THRESHOLD, x_distance < self.consts.COLLISION_THRESHOLD), jnp.logical_and(same_road, ~flag_collected) ) return collision @@ -720,34 +722,19 @@ def find_inactive_idx(collectibles_in): # Generate random spawn position using fold_in for deterministic randomness base_key = jax.random.PRNGKey(0) key_for_spawn = jax.random.fold_in(base_key, state.step_counter) - key1, key2, key3, key4, key5 = jax.random.split(key_for_spawn, 5) + key1, key2, key3, key4 = jax.random.split(key_for_spawn, 4) y_spawn = jax.random.uniform(key1, minval=-900.0, maxval=-100.0) road_spawn = jnp.array(jax.random.randint(key2, shape=(), minval=0, maxval=2), dtype=jnp.int32) color_spawn = jnp.array(jax.random.randint(key3, shape=(), minval=0, maxval=len(self.consts.COLLECTIBLE_COLORS)), dtype=jnp.int32) - # Randomly select collectible type based on spawn probabilities - # Convert probabilities (%) to cumulative distribution for sampling + # Randomly select collectible type using cumulative probability thresholds + # COLLECTIBLE_SPAWN_PROBABILITIES contains cumulative values: [35, 65, 90, 100] + # Cherry: [0-35), Balloon: [35-65), Lollypop: [65-90), IceCream: [90-100] rand_type = jax.random.uniform(key4, minval=0.0, maxval=100.0) - # Use cumulative probabilities: cherry [0-40], balloon [40-60], lollypop [60-80], ice_cream [80-100] - def select_type(rand_val): - # Returns 0=cherry, 1=balloon, 2=lollypop, 3=ice_cream - type_id = jnp.where( - rand_val < self.consts.COLLECTIBLE_SPAWN_PROBABILITIES[0], - jnp.int32(self.consts.COLLECTIBLE_TYPE_CHERRY), - jnp.where( - rand_val < self.consts.COLLECTIBLE_SPAWN_PROBABILITIES[1], - jnp.int32(self.consts.COLLECTIBLE_TYPE_BALLOON), - jnp.where( - rand_val < self.consts.COLLECTIBLE_SPAWN_PROBABILITIES[2], - jnp.int32(self.consts.COLLECTIBLE_TYPE_LOLLYPOP), - jnp.int32(self.consts.COLLECTIBLE_TYPE_ICE_CREAM) - ) - ) - ) - return type_id - - type_id_spawn = select_type(rand_type) + # Use searchsorted for efficient threshold lookup + type_id_spawn = jnp.searchsorted(self.consts.COLLECTIBLE_SPAWN_PROBABILITIES, rand_type, side='right') + type_id_spawn = jnp.clip(type_id_spawn, 0, 3).astype(jnp.int32) # Calculate X position on road segment_spawn = self._get_road_segment(y_spawn) @@ -761,34 +748,32 @@ def select_type(rand_val): # Create mask for which collectibles to update update_mask = (jnp.arange(self.consts.MAX_COLLECTIBLES) == spawn_idx) & should_spawn & has_inactive_slot - # Update collectibles with proper masking - updated_collectibles = Collectible( - y=jnp.where(update_mask, y_spawn, state.collectibles.y), - x=jnp.where(update_mask, x_spawn, state.collectibles.x), - road=jnp.where(update_mask, road_spawn, state.collectibles.road), - color_idx=jnp.where(update_mask, color_spawn, state.collectibles.color_idx), - type_id=jnp.where(update_mask, type_id_spawn, state.collectibles.type_id), - active=jnp.where(update_mask, True, state.collectibles.active), - ) + # Update collectibles with proper masking - spawn new items + spawned_y = jnp.where(update_mask, y_spawn, state.collectibles.y) + spawned_x = jnp.where(update_mask, x_spawn, state.collectibles.x) + spawned_road = jnp.where(update_mask, road_spawn, state.collectibles.road) + spawned_color_idx = jnp.where(update_mask, color_spawn, state.collectibles.color_idx) + spawned_type_id = jnp.where(update_mask, type_id_spawn, state.collectibles.type_id) + spawned_active = jnp.where(update_mask, True, state.collectibles.active) # Despawn logic - remove collectibles too far from player def check_despawn(idx): - c_y = updated_collectibles.y[idx] - c_active = updated_collectibles.active[idx] + c_y = spawned_y[idx] + c_active = spawned_active[idx] distance = jnp.abs(new_player_y - c_y) too_far = distance > self.consts.COLLECTIBLE_DESPAWN_DISTANCE should_despawn = jnp.logical_and(c_active, too_far) return should_despawn despawn_mask = jax.vmap(check_despawn)(jnp.arange(self.consts.MAX_COLLECTIBLES)) - new_active = jnp.logical_and(updated_collectibles.active, ~despawn_mask) + active_after_despawn = jnp.logical_and(spawned_active, ~despawn_mask) # Collision detection def check_collision(idx): - c_y = updated_collectibles.y[idx] - c_x = updated_collectibles.x[idx] - c_road = updated_collectibles.road[idx] - c_active = updated_collectibles.active[idx] + c_y = spawned_y[idx] + c_x = spawned_x[idx] + c_road = spawned_road[idx] + c_active = spawned_active[idx] y_distance = jnp.abs(new_player_y - c_y) x_distance = jnp.abs(player_x - c_x) @@ -798,7 +783,7 @@ def check_collision(idx): ) collision = jnp.logical_and( - jnp.logical_and(y_distance < 5, x_distance < 5), + jnp.logical_and(y_distance < self.consts.COLLISION_THRESHOLD, x_distance < self.consts.COLLISION_THRESHOLD), jnp.logical_and(same_road, c_active) ) return collision @@ -806,12 +791,12 @@ def check_collision(idx): collections = jax.vmap(check_collision)(jnp.arange(self.consts.MAX_COLLECTIBLES)) # Deactivate collected items - new_active = jnp.logical_and(new_active, ~collections) + final_active = jnp.logical_and(active_after_despawn, ~collections) # Update score - use type_id to look up score value def get_collection_score(idx): is_collected = collections[idx] - type_id = updated_collectibles.type_id[idx] + type_id = spawned_type_id[idx] # Look up score based on type_id using array indexing score = self.consts.COLLECTIBLE_SCORES[type_id] return jnp.where(is_collected, score, 0) @@ -819,13 +804,14 @@ def get_collection_score(idx): score_array = jax.vmap(get_collection_score)(jnp.arange(self.consts.MAX_COLLECTIBLES)) score_delta = jnp.sum(score_array) + # Create final collectibles state updated_collectibles = Collectible( - y=updated_collectibles.y, - x=updated_collectibles.x, - road=updated_collectibles.road, - color_idx=updated_collectibles.color_idx, - type_id=updated_collectibles.type_id, - active=new_active, + y=spawned_y, + x=spawned_x, + road=spawned_road, + color_idx=spawned_color_idx, + type_id=spawned_type_id, + active=final_active, ) return updated_collectibles, score_delta, new_collectible_timer @@ -1038,7 +1024,6 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: current_road=state.player_car.current_road, speed=player_speed, is_jumping=is_jumping, - is_on_road=is_on_road, step_counter=state.step_counter, width=state.player_car.position.width, height=state.player_car.position.height, @@ -1383,10 +1368,9 @@ def _enemy_collision_step_main(self, state: UpNDownState) -> UpNDownState: dy = jnp.abs(state.enemy_cars.position.y - player_y) wrapped_dy = jnp.minimum(dy, self.consts.TRACK_LENGTH - dy) - # For ground collision: only trigger when enemy position is within 3 pixels - ground_collision_distance = 3.0 - overlap_x_ground = dx <= ground_collision_distance - overlap_y_ground = wrapped_dy <= ground_collision_distance + # For ground collision: only trigger when enemy position is within tight distance + overlap_x_ground = dx <= self.consts.GROUND_COLLISION_DISTANCE + overlap_y_ground = wrapped_dy <= self.consts.GROUND_COLLISION_DISTANCE # For landing collision: use larger distance and road-independent (for crossings) overlap_x_landing = dx <= self.consts.LANDING_COLLISION_DISTANCE overlap_y_landing = wrapped_dy <= self.consts.LANDING_COLLISION_DISTANCE @@ -1454,10 +1438,13 @@ def handle_ground_collision(): ) def _passive_score_step_main(self, state: UpNDownState) -> UpNDownState: - """Award passive score every 60 steps after the player has started moving.""" + """Award passive score at regular intervals after the player has started moving.""" bonus = jax.lax.cond( - jnp.logical_and(state.round_started, state.movement_steps % 60 == 0), - lambda _: jnp.int32(10), + jnp.logical_and( + state.round_started, + state.movement_steps % self.consts.PASSIVE_SCORE_INTERVAL == 0, + ), + lambda _: jnp.int32(self.consts.PASSIVE_SCORE_AMOUNT), lambda _: jnp.int32(0), operand=None, ) @@ -1695,22 +1682,21 @@ def _pad_mask(mask): self.flag_solid_mask = self.flag_base_mask != self.jr.TRANSPARENT_ID self.flag_palette_ids = self._compute_flag_palette_ids() - # Precompute collectible mask data for recoloring (unified for all types: cherry, balloon, lollypop, ice cream) + # Precompute collectible mask data for recoloring (unified for all types) + # Reuse the same palette IDs since all collectibles use FLAG_COLORS + self.collectible_palette_ids = self.flag_palette_ids + self.cherry_base_mask = self.SHAPE_MASKS["cherry"] self.cherry_solid_mask = self.cherry_base_mask != self.jr.TRANSPARENT_ID - self.cherry_palette_ids = self._compute_flag_palette_ids() self.balloon_base_mask = self.SHAPE_MASKS["balloon"] self.balloon_solid_mask = self.balloon_base_mask != self.jr.TRANSPARENT_ID - self.balloon_palette_ids = self._compute_flag_palette_ids() self.lollypop_base_mask = self.SHAPE_MASKS["lollypop"] self.lollypop_solid_mask = self.lollypop_base_mask != self.jr.TRANSPARENT_ID - self.lollypop_palette_ids = self._compute_flag_palette_ids() self.ice_cream_base_mask = self.SHAPE_MASKS["ice_cream"] self.ice_cream_solid_mask = self.ice_cream_base_mask != self.jr.TRANSPARENT_ID - self.ice_cream_palette_ids = self._compute_flag_palette_ids() # Score rendering helpers self.score_digit_masks = self.SHAPE_MASKS["score_digits"] @@ -1736,7 +1722,6 @@ def _get_road_sprite_sizes(self, road_files: list[str]) -> list: sprite = jnp.load(f"{road_dir}/{sprite_name}") sizes.append(sprite.shape[0]) complete_size = int(sum(sizes)) - jax.debug.print("Complete road size: {}", complete_size) return sizes, complete_size def _find_palette_id(self, rgba: jnp.ndarray) -> int: @@ -2001,25 +1986,20 @@ def render_collectible(carry, collectible_idx): ) def get_sprite_and_mask(type_id): - cherry_result = (self.cherry_base_mask, self.cherry_solid_mask, self.cherry_palette_ids) - balloon_result = (self.balloon_base_mask, self.balloon_solid_mask, self.balloon_palette_ids) - lollypop_result = (self.lollypop_base_mask, self.lollypop_solid_mask, self.lollypop_palette_ids) - ice_cream_result = (self.ice_cream_base_mask, self.ice_cream_solid_mask, self.ice_cream_palette_ids) - return jax.lax.cond( - type_id == self.consts.COLLECTIBLE_TYPE_CHERRY, - lambda _: cherry_result, - lambda _: jax.lax.cond( - type_id == self.consts.COLLECTIBLE_TYPE_BALLOON, - lambda _: balloon_result, - lambda _: jax.lax.cond( - type_id == self.consts.COLLECTIBLE_TYPE_LOLLYPOP, - lambda _: lollypop_result, - lambda _: ice_cream_result, - operand=None, - ), - operand=None, - ), - operand=None, + # Use switch for O(1) lookup instead of nested conditionals + def get_cherry(_): + return (self.cherry_base_mask, self.cherry_solid_mask, self.collectible_palette_ids) + def get_balloon(_): + return (self.balloon_base_mask, self.balloon_solid_mask, self.collectible_palette_ids) + def get_lollypop(_): + return (self.lollypop_base_mask, self.lollypop_solid_mask, self.collectible_palette_ids) + def get_ice_cream(_): + return (self.ice_cream_base_mask, self.ice_cream_solid_mask, self.collectible_palette_ids) + + return jax.lax.switch( + type_id, + [get_cherry, get_balloon, get_lollypop, get_ice_cream], + None, ) base_mask, solid_mask, palette_ids = get_sprite_and_mask(collectible_type_id) From c82b80eb8c083dfebe3e696228bcc9cfe903656a Mon Sep 17 00:00:00 2001 From: shaik05 Date: Sun, 21 Dec 2025 11:06:49 +0100 Subject: [PATCH 59/76] modified respawn --- src/jaxatari/games/jax_upndown.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 4933e8cbf..311b0b47d 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -19,6 +19,10 @@ class UpNDownConstants(NamedTuple): MAX_SPEED: int = 6 INITIAL_LIVES: int = 5 JUMP_ARC_HEIGHT: float = 18.0 + RESPAWN_DELAY_FRAMES: int = 60 + RESPAWN_Y: int = 0 + RESPAWN_X: int = 30 + ALL_FLAGS_BONUS: int = 1000 # Enemy spawning and movement MAX_ENEMY_CARS: int = 8 ENEMY_SPAWN_INTERVAL: int = 80 @@ -144,7 +148,6 @@ class EnemyCars(NamedTuple): class UpNDownState(NamedTuple): score: chex.Array - lives: chex.Array difficulty: chex.Array jump_cooldown: chex.Array post_jump_cooldown: chex.Array @@ -1332,6 +1335,8 @@ def _respawn_after_collision(self, state: UpNDownState, new_lives: chex.Array) - return UpNDownState( score=state.score, lives=new_lives, + is_dead=jnp.array(False), + respawn_timer=jnp.array(0, dtype=jnp.int32), difficulty=state.difficulty, jump_cooldown=jnp.array(0, dtype=jnp.int32), post_jump_cooldown=jnp.array(0, dtype=jnp.int32), @@ -1491,6 +1496,8 @@ def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: state = UpNDownState( score=0, lives=jnp.array(self.consts.INITIAL_LIVES, dtype=jnp.int32), + is_dead=jnp.array(False), + respawn_timer=jnp.array(0, dtype=jnp.int32), difficulty=self.consts.DIFFICULTIES[0], jump_cooldown=0, post_jump_cooldown=0, @@ -2007,6 +2014,7 @@ def get_ice_cream(_): colored_mask = jnp.where( (base_mask != self.jr.TRANSPARENT_ID) & (base_mask != 0), color_id, + base_mask, ) raster = jax.lax.cond( From e49e74cf439746592cefd48a0a1f95a281ef0ac6 Mon Sep 17 00:00:00 2001 From: shaik05 Date: Fri, 6 Mar 2026 13:15:02 +0100 Subject: [PATCH 60/76] Allow backward jumping and remove steep road mechanics in UpNDown --- src/jaxatari/games/jax_upndown.py | 98 ++++++++++++++++++++++++++++--- 1 file changed, 91 insertions(+), 7 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 311b0b47d..5d988acf5 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -47,7 +47,11 @@ class UpNDownConstants(NamedTuple): LANDING_COLLISION_DISTANCE: float = 8.0 # Larger collision distance when landing (for crossings) GROUND_COLLISION_DISTANCE: float = 3.0 # Tight collision distance for ground collisions LATE_JUMP_ENEMY_SCORE: int = 400 - STEEP_ROAD_SPEED_REDUCTION_INTERVAL: int = 8 # Frames between each speed reduction on steep roads + STEEP_ROAD_SPEED_REDUCTION_INTERVAL: int = 12 # Frames between each speed reduction on steep roads + STEEP_ROAD_MIN_SPEED: float = -2.0 # Minimum speed on steep roads + STEEP_ROAD_JUMP_BOOST: float = 1.5 # Multiplier for jump height on steep roads + STEEP_ROAD_RECOVERY_BOOST: float = 0.8 # Speed boost after leaving steep road + STEEP_ROAD_COOLDOWN: int = 5 PASSIVE_SCORE_INTERVAL: int = 60 # Steps between passive score awards PASSIVE_SCORE_AMOUNT: int = 10 # Points awarded for passive scoring COLLISION_THRESHOLD: float = 5.0 # Distance threshold for flag/collectible collision @@ -209,8 +213,8 @@ def _compute_movement_timing(self, speed: chex.Array, step_counter: chex.Array) Returns: Tuple of (move_y, move_x, step_size, speed_sign) """ - abs_speed = jnp.abs(speed) - speed_index = jnp.minimum(abs_speed, jnp.int32(self._speed_dividers.shape[0] - 1)) + abs_speed = jnp.abs(speed).astype(jnp.int32) + speed_index = jnp.minimum(abs_speed, self._speed_dividers.shape[0] - 1).astype(jnp.int32) speed_divider = self._speed_dividers[speed_index] effective_divider = jnp.maximum(1, speed_divider) period = jnp.maximum(1, 16 // effective_divider) @@ -222,6 +226,69 @@ def _compute_movement_timing(self, speed: chex.Array, step_counter: chex.Array) step_size = jnp.where(speed_index >= self.consts.MAX_SPEED, 1.5, 1.0) return move_y, move_x, step_size, speed_sign + def _apply_steep_road_penalty( + self, + speed: chex.Array, + is_on_steep_road: chex.Array, + steep_road_timer: chex.Array, + is_jumping: chex.Array, + jump_cooldown: chex.Array, + ) -> Tuple[chex.Array, chex.Array, chex.Array]: + """ + Apply enhanced steep road penalty with perfect balance and edge case handling. + + - Dynamically reduces speed on steep roads when going upward. + - Provides jump boost and recovery for better flow. + - Includes cooldown to prevent rapid reductions. + + Returns: (new_speed, new_timer, jump_boost_multiplier) + """ + going_up = speed > 0 + on_steep_going_up = jnp.logical_and(is_on_steep_road, going_up) + in_cooldown = steep_road_timer < 0 # Negative timer indicates cooldown + + # Increment timer only if not in cooldown and on steep road going up + timer_increment = jax.lax.cond( + jnp.logical_and(on_steep_going_up, jnp.logical_not(in_cooldown)), + lambda _: 1, + lambda _: 0, + operand=None, + ) + new_timer = steep_road_timer + timer_increment + + # Apply reduction when timer reaches interval and not in cooldown + should_reduce = jnp.logical_and( + on_steep_going_up, + jnp.logical_and(new_timer >= self.consts.STEEP_ROAD_SPEED_REDUCTION_INTERVAL, jnp.logical_not(in_cooldown)) + ) + + # Proportional reduction: stronger for higher speeds, with minimum cap + reduction_factor = jnp.maximum(0.05, speed * 0.15) # 5-15% of speed + reduced_speed = jnp.maximum(speed - reduction_factor, self.consts.STEEP_ROAD_MIN_SPEED) + + # Set cooldown after reduction (negative timer) + final_timer = jax.lax.cond( + should_reduce, + lambda _: -self.consts.STEEP_ROAD_COOLDOWN, + lambda _: new_timer, + operand=None, + ) + + # Recovery boost after leaving steep road (not jumping) + just_left_steep = jnp.logical_and(jnp.logical_not(on_steep_going_up), jnp.logical_not(is_jumping)) + recovery_boost = jax.lax.cond(just_left_steep, lambda _: self.consts.STEEP_ROAD_RECOVERY_BOOST, lambda _: 0.0, operand=None) + + # Jump boost if jumping on steep road + jump_boost = jax.lax.cond( + jnp.logical_and(on_steep_going_up, jump_cooldown > 0), + lambda _: self.consts.STEEP_ROAD_JUMP_BOOST, + lambda _: 1.0, + operand=None, + ) + + final_speed = jax.lax.cond(should_reduce, lambda _: reduced_speed + recovery_boost, lambda _: speed + recovery_boost, operand=None) + + return final_speed, final_timer, jump_boost @partial(jax.jit, static_argnums=(0,)) def _get_slope_and_intercept_from_indices(self, current_road: chex.Array, road_index_A: chex.Array, road_index_B: chex.Array) -> Tuple[chex.Array, chex.Array]: @@ -854,7 +921,7 @@ def _death_step(self, state: UpNDownState) -> UpNDownState: x=jnp.array(self.consts.RESPAWN_X, dtype=jnp.float32), y=jnp.array(self.consts.RESPAWN_Y, dtype=jnp.float32), ), - speed=0, + speed=jnp.array(0.0, dtype=jnp.float32), current_road=0, ), lambda _: state.player_car, @@ -872,7 +939,7 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: up = jnp.logical_or(action == Action.UP, action == Action.UPFIRE) down = jnp.logical_or(action == Action.DOWN, action == Action.DOWNFIRE) jump = jnp.logical_or(action == Action.FIRE, jnp.logical_or(action == Action.UPFIRE, action == Action.DOWNFIRE)) - player_speed = state.player_car.speed + player_speed = state.player_car.speed.astype(jnp.float32) player_speed = jax.lax.cond( jnp.logical_and(state.player_car.speed < self.consts.MAX_SPEED, up), @@ -899,6 +966,9 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: # 1. Player is on a steep road section # 2. Player is not jumping # 3. Player has positive speed (going upward) + player_speed, steep_road_timer, jump_boost_multiplier = self._apply_steep_road_penalty( + player_speed, is_on_steep_road, state.steep_road_timer, state.is_jumping, state.jump_cooldown + ) on_steep_going_up = jnp.logical_and( is_on_steep_road, jnp.logical_and( @@ -936,7 +1006,7 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: can_start_jump = jnp.logical_and(state.jump_cooldown == 0, state.post_jump_cooldown == 0) is_jumping = jnp.logical_or( jnp.logical_and(state.is_jumping, state.jump_cooldown > 0), - jnp.logical_and(state.is_on_road, jnp.logical_and(player_speed >= 0, jnp.logical_and(can_start_jump, jump))), + jnp.logical_and(state.is_on_road,jnp.logical_and(can_start_jump, jump)), ) # Detect when a new jump is starting (was not jumping, now is jumping) @@ -1018,6 +1088,7 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: ) is_on_road = ~is_jumping is_landing = jnp.logical_and(state.jump_cooldown == 1, jump_cooldown == 0) + jump_arc_height = self.consts.JUMP_ARC_HEIGHT * jump_boost_multiplier updated_player_car = self._advance_player_car( position_x=state.player_car.position.x, @@ -1324,7 +1395,7 @@ def _respawn_after_collision(self, state: UpNDownState, new_lives: chex.Array) - width=self.consts.PLAYER_SIZE[0], height=self.consts.PLAYER_SIZE[1], ), - speed=jnp.array(0, dtype=jnp.int32), + speed=jnp.array(0.0, dtype=jnp.float32), direction_x=jnp.array(0, dtype=jnp.int32), current_road=respawn_road, road_index_A=start_segment, @@ -1625,6 +1696,19 @@ def __init__(self, consts: UpNDownConstants = None): channels=3, #downscale=(84, 84) ) + def _createBackgroundSprite(self, dimensions: Tuple[int, int]) -> jnp.ndarray: + + height, width = dimensions + # Create a vertical gradient: blue at top, lighter blue at bottom + top_color = jnp.array([135, 206, 235, 255], dtype=jnp.uint8) # Sky blue + bottom_color = jnp.array([173, 216, 230, 255], dtype=jnp.uint8) # Lighter sky blue + + # Linear interpolation for gradient + y_coords = jnp.arange(height, dtype=jnp.float32) / (height - 1) + gradient = jnp.outer(y_coords, bottom_color - top_color) + top_color + gradient = jnp.clip(gradient, 0, 255).astype(jnp.uint8) + + return gradient self.jr = render_utils.JaxRenderingUtils(self.config) background = self._createBackgroundSprite(self.config.game_dimensions) From 4780ff4b3cf33f06d87bb0a252852cf27d02eeea Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sun, 21 Dec 2025 15:38:54 +0100 Subject: [PATCH 61/76] improve code quality --- src/jaxatari/games/jax_upndown.py | 346 +++++++++++++----------------- 1 file changed, 144 insertions(+), 202 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 5d988acf5..7735a0987 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -1,5 +1,4 @@ import os -import math from functools import partial from typing import NamedTuple, Tuple @@ -14,7 +13,7 @@ from jaxatari.environment import JaxEnvironment, JAXAtariAction as Action class UpNDownConstants(NamedTuple): - FRAME_SKIP: int = 4 # Used by AtariWrapper + FRAME_SKIP: int = 4 DIFFICULTIES: chex.Array = jnp.array([0, 1, 2, 3, 4, 5]) MAX_SPEED: int = 6 INITIAL_LIVES: int = 5 @@ -326,12 +325,7 @@ def _get_x_on_road(self, y: chex.Array, road_segment: chex.Array, track_corners_ x2 = track_corners_x[road_segment + 1] # Linear interpolation: x = x1 + (y - y1) * (x2 - x1) / (y2 - y1) - t = jax.lax.cond( - y2 != y1, - lambda _: (y - y1) / (y2 - y1), - lambda _: 0.0, - operand=None, - ) + t = jnp.where(y2 != y1, (y - y1) / (y2 - y1), 0.0) return x1 + t * (x2 - x1) @partial(jax.jit, static_argnums=(0,)) @@ -582,7 +576,7 @@ def _advance_player_car( ) # Wrap Y position for looping track - wrapped_y = -((new_player_y * -1) % 1036) + wrapped_y = -((new_player_y * -1) % self.consts.TRACK_LENGTH) return Car( position=EntityPosition( @@ -654,25 +648,14 @@ def _advance_car_core( operand=None, ) - wrapped_y = -((new_y * -1) % 1036) + wrapped_y = -((new_y * -1) % self.consts.TRACK_LENGTH) # Update road segment indices based on new position segment_from_y = self._get_road_segment(new_y) - # Update road indices to track the current segment - next_road_index_A = jax.lax.cond( - current_road == 0, - lambda _: segment_from_y, - lambda _: road_index_A, - operand=None, - ) - - next_road_index_B = jax.lax.cond( - current_road == 1, - lambda _: segment_from_y, - lambda _: road_index_B, - operand=None, - ) + # Update road indices to track the current segment (use jnp.where for branchless execution) + next_road_index_A = jnp.where(current_road == 0, segment_from_y, road_index_A) + next_road_index_B = jnp.where(current_road == 1, segment_from_y, road_index_B) return Car( position=EntityPosition( @@ -689,6 +672,7 @@ def _advance_car_core( type=car_type, ) + @partial(jax.jit, static_argnums=(0,)) def _flag_step(self, state: UpNDownState, new_player_y: chex.Array, player_x: chex.Array, current_road: chex.Array) -> Tuple[Flag, chex.Array, chex.Array]: """Update flag collection state and score. @@ -748,7 +732,8 @@ def check_flag_collision(flag_idx): ) return new_flags, flag_score, new_flags_collected_mask - + + @partial(jax.jit, static_argnums=(0,)) def _collectible_step(self, state: UpNDownState, new_player_y: chex.Array, player_x: chex.Array, current_road: chex.Array) -> Tuple[Collectible, chex.Array, chex.Array]: """Update collectible spawning, despawning, and collection (unified for all types). @@ -764,30 +749,20 @@ def _collectible_step(self, state: UpNDownState, new_player_y: chex.Array, playe Returns: Tuple of (updated_collectibles, score_delta, new_spawn_timer) """ - # Collectible spawning logic - decrement timer and spawn when ready - new_collectible_timer = jax.lax.cond( + # Collectible spawning logic - decrement timer and spawn when ready (use jnp.where for branchless) + new_collectible_timer = jnp.where( state.collectible_spawn_timer <= 0, - lambda _: self.consts.COLLECTIBLE_SPAWN_INTERVAL, - lambda _: state.collectible_spawn_timer - 1, - operand=None, + self.consts.COLLECTIBLE_SPAWN_INTERVAL, + state.collectible_spawn_timer - 1, ) # Attempt to spawn when timer hits 0 should_spawn = state.collectible_spawn_timer <= 0 - # Find first inactive collectible slot - def find_inactive_idx(collectibles_in): - inactive_mask = ~collectibles_in.active - first_inactive = jnp.argmax(inactive_mask.astype(jnp.int32)) - has_inactive = jnp.any(inactive_mask) - return jax.lax.cond( - has_inactive, - lambda _: first_inactive, - lambda _: jnp.array(0, dtype=jnp.int32), - operand=None, - ), has_inactive - - spawn_idx, has_inactive_slot = find_inactive_idx(state.collectibles) + inactive_mask = ~state.collectibles.active + first_inactive = jnp.argmax(inactive_mask.astype(jnp.int32)) + has_inactive_slot = jnp.any(inactive_mask) + spawn_idx = jnp.where(has_inactive_slot, first_inactive, jnp.array(0, dtype=jnp.int32)) # Generate random spawn position using fold_in for deterministic randomness base_key = jax.random.PRNGKey(0) @@ -806,13 +781,12 @@ def find_inactive_idx(collectibles_in): type_id_spawn = jnp.searchsorted(self.consts.COLLECTIBLE_SPAWN_PROBABILITIES, rand_type, side='right') type_id_spawn = jnp.clip(type_id_spawn, 0, 3).astype(jnp.int32) - # Calculate X position on road + # Calculate X position on road (use jnp.where for branchless) segment_spawn = self._get_road_segment(y_spawn) - x_spawn = jax.lax.cond( + x_spawn = jnp.where( road_spawn == 0, - lambda _: self._get_x_on_road(y_spawn, segment_spawn, self.consts.FIRST_TRACK_CORNERS_X), - lambda _: self._get_x_on_road(y_spawn, segment_spawn, self.consts.SECOND_TRACK_CORNERS_X), - operand=None, + self._get_x_on_road(y_spawn, segment_spawn, self.consts.FIRST_TRACK_CORNERS_X), + self._get_x_on_road(y_spawn, segment_spawn, self.consts.SECOND_TRACK_CORNERS_X), ) # Create mask for which collectibles to update @@ -863,16 +837,9 @@ def check_collision(idx): # Deactivate collected items final_active = jnp.logical_and(active_after_despawn, ~collections) - # Update score - use type_id to look up score value - def get_collection_score(idx): - is_collected = collections[idx] - type_id = spawned_type_id[idx] - # Look up score based on type_id using array indexing - score = self.consts.COLLECTIBLE_SCORES[type_id] - return jnp.where(is_collected, score, 0) - - score_array = jax.vmap(get_collection_score)(jnp.arange(self.consts.MAX_COLLECTIBLES)) - score_delta = jnp.sum(score_array) + # Update score - vectorized lookup without vmap overhead + scores = self.consts.COLLECTIBLE_SCORES[spawned_type_id] + score_delta = jnp.sum(jnp.where(collections, scores, 0)) # Create final collectibles state updated_collectibles = Collectible( @@ -885,74 +852,66 @@ def get_collection_score(idx): ) return updated_collectibles, score_delta, new_collectible_timer + + @partial(jax.jit, static_argnums=(0,)) def _death_step(self, state: UpNDownState) -> UpNDownState: - # Player on water road (index 2 assumed water) + """Handle player death when on water road (index 2).""" + # Player on water road (index 2 assumed water) died = jnp.logical_and( state.player_car.current_road == 2, ~state.is_dead, - ) + ) - lives = jax.lax.cond( + # Use jnp.where for branchless execution + lives = jnp.where(died, state.lives - 1, state.lives) + respawn_timer = jnp.where( died, - lambda _: state.lives - 1, - lambda _: state.lives, - operand=None, - ) - lives = jax.lax.cond( - died, - lambda _: state.lives - 1, - lambda _: state.lives, - operand=None, - ) - respawn_timer = jax.lax.cond( - died, - lambda _: jnp.array(self.consts.RESPAWN_DELAY_FRAMES), - lambda _: jnp.maximum(state.respawn_timer - 1, 0), - operand=None, - ) + jnp.array(self.consts.RESPAWN_DELAY_FRAMES), + jnp.maximum(state.respawn_timer - 1, 0), + ) is_dead = jnp.logical_and( - jnp.logical_or(state.is_dead, died), - respawn_timer > 0) - - player_car = jax.lax.cond( - jnp.logical_and(state.is_dead, respawn_timer == 0), - lambda _: state.player_car._replace( - position=state.player_car.position._replace( - x=jnp.array(self.consts.RESPAWN_X, dtype=jnp.float32), - y=jnp.array(self.consts.RESPAWN_Y, dtype=jnp.float32), - ), - speed=jnp.array(0.0, dtype=jnp.float32), - current_road=0, - ), - lambda _: state.player_car, - operand=None, - ) + jnp.logical_or(state.is_dead, died), + respawn_timer > 0, + ) + + # Respawn player when dead and timer expires + should_respawn = jnp.logical_and(state.is_dead, respawn_timer == 0) + new_position = state.player_car.position._replace( + x=jnp.where(should_respawn, jnp.array(self.consts.RESPAWN_X, dtype=jnp.float32), state.player_car.position.x), + y=jnp.where(should_respawn, jnp.array(self.consts.RESPAWN_Y, dtype=jnp.float32), state.player_car.position.y), + ) + player_car = state.player_car._replace( + position=new_position, + speed=jnp.where(should_respawn, 0, state.player_car.speed), + current_road=jnp.where(should_respawn, 0, state.player_car.current_road), + ) + return state._replace( - lives=lives, - is_dead=is_dead, - respawn_timer=respawn_timer, - player_car=player_car, - ) + lives=lives, + is_dead=is_dead, + respawn_timer=respawn_timer, + player_car=player_car, + ) + @partial(jax.jit, static_argnums=(0,)) def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: up = jnp.logical_or(action == Action.UP, action == Action.UPFIRE) down = jnp.logical_or(action == Action.DOWN, action == Action.DOWNFIRE) jump = jnp.logical_or(action == Action.FIRE, jnp.logical_or(action == Action.UPFIRE, action == Action.DOWNFIRE)) player_speed = state.player_car.speed.astype(jnp.float32) - player_speed = jax.lax.cond( + # Use jnp.where for branchless execution + player_speed = jnp.where( jnp.logical_and(state.player_car.speed < self.consts.MAX_SPEED, up), - lambda s: s + 1, - lambda s: s, - operand=player_speed, + player_speed + 1, + player_speed, ) - player_speed = jax.lax.cond( + player_speed = jnp.where( jnp.logical_and(state.player_car.speed > -self.consts.MAX_SPEED, down), - lambda s: s - 1, - lambda s: s, - operand=player_speed, + player_speed - 1, + player_speed, ) # Check if on a steep road section (no X direction change) and apply speed reduction @@ -976,31 +935,28 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: player_speed > 0 ) ) - # Update steep road timer - increment when on steep road going up, reset otherwise - steep_road_timer = jax.lax.cond( + # Update steep road timer - increment when on steep road going up, reset otherwise (use jnp.where) + steep_road_timer = jnp.where( on_steep_going_up, - lambda _: state.steep_road_timer + 1, - lambda _: jnp.array(0, dtype=jnp.int32), - operand=None, + state.steep_road_timer + 1, + jnp.array(0, dtype=jnp.int32), ) # Only reduce speed when timer reaches the interval threshold should_reduce_speed = jnp.logical_and( on_steep_going_up, steep_road_timer >= self.consts.STEEP_ROAD_SPEED_REDUCTION_INTERVAL ) - # Gradually reduce speed toward -2 when on steep section without jumping - player_speed = jax.lax.cond( + # Gradually reduce speed toward -2 when on steep section without jumping (use jnp.where) + player_speed = jnp.where( should_reduce_speed, - lambda s: jnp.maximum(s - 1, jnp.int32(-2)), - lambda s: s, - operand=player_speed, + jnp.maximum(player_speed - 1, jnp.int32(-2)), + player_speed, ) - # Reset timer after speed reduction - steep_road_timer = jax.lax.cond( + # Reset timer after speed reduction (use jnp.where) + steep_road_timer = jnp.where( should_reduce_speed, - lambda _: jnp.array(0, dtype=jnp.int32), - lambda _: steep_road_timer, - operand=None, + jnp.array(0, dtype=jnp.int32), + steep_road_timer, ) can_start_jump = jnp.logical_and(state.jump_cooldown == 0, state.post_jump_cooldown == 0) @@ -1014,81 +970,63 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: # Calculate jump slope at jump start (X change per Y step) # Uses the road segment slope to follow the road trajectory - road_index = jax.lax.cond( + # Use jnp.where for branchless execution + road_index = jnp.where( state.player_car.current_road == 0, - lambda _: state.player_car.road_index_A, - lambda _: state.player_car.road_index_B, - operand=None, + state.player_car.road_index_A, + state.player_car.road_index_B, ) # Get corner coordinates for the current segment # Segment goes from corner[road_index] to corner[road_index+1] - start_x = jax.lax.cond( + # Use jnp.where for branchless execution + start_x = jnp.where( state.player_car.current_road == 0, - lambda _: self.consts.FIRST_TRACK_CORNERS_X[road_index], - lambda _: self.consts.SECOND_TRACK_CORNERS_X[road_index], - operand=None, + self.consts.FIRST_TRACK_CORNERS_X[road_index], + self.consts.SECOND_TRACK_CORNERS_X[road_index], ) - end_x = jax.lax.cond( + end_x = jnp.where( state.player_car.current_road == 0, - lambda _: self.consts.FIRST_TRACK_CORNERS_X[road_index +1], - lambda _: self.consts.SECOND_TRACK_CORNERS_X[road_index +1], - operand=None, + self.consts.FIRST_TRACK_CORNERS_X[road_index + 1], + self.consts.SECOND_TRACK_CORNERS_X[road_index + 1], ) start_y = self.consts.TRACK_CORNERS_Y[road_index] - end_y = jax.lax.cond( + end_y = jnp.where( jnp.equal(self.consts.FIRST_TRACK_CORNERS_X[road_index + 1], self.consts.FIRST_TRACK_CORNERS_X[road_index + 2]), - lambda _: self.consts.TRACK_CORNERS_Y[road_index + 2], - lambda _: self.consts.TRACK_CORNERS_Y[road_index + 1], - operand=None + self.consts.TRACK_CORNERS_Y[road_index + 2], + self.consts.TRACK_CORNERS_Y[road_index + 1], ) # Calculate slope: how much X changes per unit Y change delta_x = end_x - start_x delta_y = end_y - start_y - # Avoid division by zero for horizontal segments - new_jump_slope = jax.lax.cond( + # Avoid division by zero for horizontal segments (use jnp.where) + new_jump_slope = jnp.where( jnp.abs(delta_y) > 0.001, - lambda _: jnp.float32(delta_x) / jnp.float32(delta_y), - lambda _: jnp.float32(0.0), - operand=None, + jnp.float32(delta_x) / jnp.float32(delta_y), + jnp.float32(0.0), ) - # Lock slope at jump start, keep previous slope during jump - jump_slope = jax.lax.cond( - starting_jump, - lambda _: new_jump_slope, - lambda _: state.jump_slope, - operand=None, - ) + # Lock slope at jump start, keep previous slope during jump (use jnp.where) + jump_slope = jnp.where(starting_jump, new_jump_slope, state.jump_slope) - jump_cooldown = jax.lax.cond( + # Use jnp.where for branchless execution of jump_cooldown + jump_cooldown = jnp.where( state.jump_cooldown > 0, - lambda s: s - 1, - lambda s: jax.lax.cond( - is_jumping, - lambda _: self.consts.JUMP_FRAMES, - lambda _: 0, - operand=None, - ), - operand=state.jump_cooldown, + state.jump_cooldown - 1, + jnp.where(is_jumping, self.consts.JUMP_FRAMES, 0), ) - post_jump_cooldown = jax.lax.cond( - jnp.logical_and(state.jump_cooldown == 1, jump_cooldown == 0), - lambda _: self.consts.POST_JUMP_DELAY, - lambda _: jax.lax.cond( - state.post_jump_cooldown > 0, - lambda s: s - 1, - lambda s: s, - operand=state.post_jump_cooldown, - ), - operand=None, + # Use jnp.where for branchless execution of post_jump_cooldown + is_landing_now = jnp.logical_and(state.jump_cooldown == 1, jump_cooldown == 0) + post_jump_cooldown = jnp.where( + is_landing_now, + self.consts.POST_JUMP_DELAY, + jnp.where(state.post_jump_cooldown > 0, state.post_jump_cooldown - 1, state.post_jump_cooldown), ) is_on_road = ~is_jumping - is_landing = jnp.logical_and(state.jump_cooldown == 1, jump_cooldown == 0) - jump_arc_height = self.consts.JUMP_ARC_HEIGHT * jump_boost_multiplier + is_landing = is_landing_now updated_player_car = self._advance_player_car( position_x=state.player_car.position.x, @@ -1119,12 +1057,7 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: player_car=updated_player_car, step_counter=state.step_counter + 1, round_started=round_started_now, - movement_steps=jax.lax.cond( - round_started_now, - lambda _: state.movement_steps + 1, - lambda _: state.movement_steps, - operand=None, - ), + movement_steps=jnp.where(round_started_now, state.movement_steps + 1, state.movement_steps), steep_road_timer=steep_road_timer, jump_slope=jump_slope, ) @@ -1138,6 +1071,7 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: operand=None, ) + @partial(jax.jit, static_argnums=(0,)) def _flag_step_main(self, state: UpNDownState) -> UpNDownState: """Update flag collection state and score.""" new_player_y = state.player_car.position.y @@ -1153,7 +1087,16 @@ def _flag_step_main(self, state: UpNDownState) -> UpNDownState: flags=new_flags, flags_collected_mask=new_flags_collected_mask, ) - + + @partial(jax.jit, static_argnums=(0,)) + def _completion_bonus_step(self, state: UpNDownState) -> UpNDownState: + """Award bonus when all flags are collected.""" + all_flags_collected = jnp.all(state.flags_collected_mask) + # Use jnp.where for branchless execution + bonus = jnp.where(all_flags_collected, self.consts.ALL_FLAGS_BONUS, 0) + return state._replace(score=state.score + bonus) + + @partial(jax.jit, static_argnums=(0,)) def _collectible_step_main(self, state: UpNDownState) -> UpNDownState: """Update collectible spawning, despawning, and collection.""" new_player_y = state.player_car.position.y @@ -1170,6 +1113,7 @@ def _collectible_step_main(self, state: UpNDownState) -> UpNDownState: collectible_spawn_timer=new_collectible_timer, ) + @partial(jax.jit, static_argnums=(0,)) def _initialize_collectibles(self) -> Collectible: """Return a cleared collectible pool.""" return Collectible( @@ -1247,32 +1191,33 @@ def _enemy_step_main(self, state: UpNDownState) -> UpNDownState: active_count = jnp.sum(active_mask.astype(jnp.int32)) can_spawn = active_count < self.consts.MAX_ENEMY_CARS - spawn_timer = jax.lax.cond( + # Use jnp.where for branchless execution + spawn_timer = jnp.where( jnp.logical_and(state.enemy_spawn_timer <= 0, can_spawn), - lambda _: self.consts.ENEMY_SPAWN_INTERVAL, - lambda _: state.enemy_spawn_timer - 1, - operand=None, + self.consts.ENEMY_SPAWN_INTERVAL, + state.enemy_spawn_timer - 1, ) should_spawn = jnp.logical_and(state.enemy_spawn_timer <= 0, can_spawn) inactive_mask = jnp.logical_not(active_mask) first_inactive = jnp.argmax(inactive_mask.astype(jnp.int32)) has_inactive = jnp.any(inactive_mask) - spawn_idx = jax.lax.cond(has_inactive, lambda _: first_inactive, lambda _: jnp.array(0, dtype=jnp.int32), operand=None) + # Use jnp.where for branchless execution + spawn_idx = jnp.where(has_inactive, first_inactive, jnp.array(0, dtype=jnp.int32)) spawn_mask = (jnp.arange(self.consts.MAX_ENEMY_CARS) == spawn_idx) & should_spawn & has_inactive spawn_offset = self.consts.ENEMY_OFFSCREEN_SPAWN_OFFSET + active_count * self.consts.ENEMY_MIN_SPAWN_GAP + jax.random.uniform(key_spawn_offset, minval=0.0, maxval=40.0) spawn_side = jax.random.choice(key_spawn_side, jnp.array([-1.0, 1.0])) raw_spawn_y = state.player_car.position.y + spawn_side * spawn_offset - spawn_y = -(((raw_spawn_y) * -1) % 1036) + spawn_y = -(((raw_spawn_y) * -1) % self.consts.TRACK_LENGTH) spawn_road = jax.random.randint(key_spawn_direction, shape=(), minval=0, maxval=2) segment_spawn = self._get_road_segment(spawn_y) - spawn_x = jax.lax.cond( + # Use jnp.where for branchless execution + spawn_x = jnp.where( spawn_road == 0, - lambda _: self._get_x_on_road(spawn_y, segment_spawn, self.consts.FIRST_TRACK_CORNERS_X), - lambda _: self._get_x_on_road(spawn_y, segment_spawn, self.consts.SECOND_TRACK_CORNERS_X), - operand=None, + self._get_x_on_road(spawn_y, segment_spawn, self.consts.FIRST_TRACK_CORNERS_X), + self._get_x_on_road(spawn_y, segment_spawn, self.consts.SECOND_TRACK_CORNERS_X), ) spawn_speed_mag = jax.random.randint(key_spawn_speed, shape=(), minval=self.consts.ENEMY_SPEED_MIN, maxval=self.consts.ENEMY_SPEED_MAX + 1) @@ -1280,13 +1225,13 @@ def _enemy_step_main(self, state: UpNDownState) -> UpNDownState: spawn_speed = spawn_speed_mag * spawn_speed_sign spawn_type = jax.random.randint(key_spawn_type, shape=(), minval=0, maxval=4) - direction_raw = jax.lax.cond( + # Use jnp.where for branchless execution + direction_raw = jnp.where( spawn_road == 0, - lambda _: self.consts.FIRST_TRACK_CORNERS_X[segment_spawn+1] - self.consts.FIRST_TRACK_CORNERS_X[segment_spawn], - lambda _: self.consts.SECOND_TRACK_CORNERS_X[segment_spawn+1] - self.consts.SECOND_TRACK_CORNERS_X[segment_spawn], - operand=None, + self.consts.FIRST_TRACK_CORNERS_X[segment_spawn+1] - self.consts.FIRST_TRACK_CORNERS_X[segment_spawn], + self.consts.SECOND_TRACK_CORNERS_X[segment_spawn+1] - self.consts.SECOND_TRACK_CORNERS_X[segment_spawn], ) - spawn_direction_x = jax.lax.cond(direction_raw > 0, lambda _: 1, lambda _: -1, operand=None) + spawn_direction_x = jnp.where(direction_raw > 0, 1, -1) enemy_position_x = jnp.where(spawn_mask, spawn_x, state.enemy_cars.position.x) enemy_position_y = jnp.where(spawn_mask, spawn_y, state.enemy_cars.position.y) @@ -1338,7 +1283,7 @@ def _enemy_step_main(self, state: UpNDownState) -> UpNDownState: enemy_age = jnp.where(enemy_active, enemy_age + 1, enemy_age) delta_y = moved_position_y - state.player_car.position.y - wrapped_dist = jnp.minimum(jnp.abs(delta_y), 1036 - jnp.abs(delta_y)) + wrapped_dist = jnp.minimum(jnp.abs(delta_y), self.consts.TRACK_LENGTH - jnp.abs(delta_y)) far_mask = wrapped_dist > self.consts.ENEMY_DESPAWN_DISTANCE age_mask = enemy_age > self.consts.ENEMY_MAX_AGE despawn_mask = jnp.logical_and(enemy_active, jnp.logical_or(far_mask, age_mask)) @@ -1513,17 +1458,14 @@ def handle_ground_collision(): operand=None, ) + @partial(jax.jit, static_argnums=(0,)) def _passive_score_step_main(self, state: UpNDownState) -> UpNDownState: """Award passive score at regular intervals after the player has started moving.""" - bonus = jax.lax.cond( - jnp.logical_and( - state.round_started, - state.movement_steps % self.consts.PASSIVE_SCORE_INTERVAL == 0, - ), - lambda _: jnp.int32(self.consts.PASSIVE_SCORE_AMOUNT), - lambda _: jnp.int32(0), - operand=None, + should_award = jnp.logical_and( + state.round_started, + state.movement_steps % self.consts.PASSIVE_SCORE_INTERVAL == 0, ) + bonus = jnp.where(should_award, jnp.int32(self.consts.PASSIVE_SCORE_AMOUNT), jnp.int32(0)) return state._replace(score=state.score + bonus) @@ -1733,7 +1675,7 @@ def _createBackgroundSprite(self, dimensions: Tuple[int, int]) -> jnp.ndarray: self.view_height = self.config.game_dimensions[0] # Precompute offsets so repeated road tiles can wrap seamlessly without gaps. road_cycle = max(1, self.complete_road_size) - repeats = max(1, math.ceil(self.view_height / road_cycle) + 2) + repeats = max(1, int(-(-self.view_height // road_cycle)) + 2) # Ceiling division trick self._road_tile_offsets = jnp.arange(-repeats, repeats + 1, dtype=jnp.int32) * jnp.int32(self.complete_road_size) self._num_road_tiles = int(self._road_tile_offsets.shape[0]) From 2fcf3de7b77c39b5d503bc9aaad6fc96176d3d9c Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sun, 21 Dec 2025 15:53:34 +0100 Subject: [PATCH 62/76] improve enemy spawning --- src/jaxatari/games/jax_upndown.py | 73 ++++++++++++++++++++++--------- 1 file changed, 53 insertions(+), 20 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 7735a0987..b94c73a8a 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -24,17 +24,21 @@ class UpNDownConstants(NamedTuple): ALL_FLAGS_BONUS: int = 1000 # Enemy spawning and movement MAX_ENEMY_CARS: int = 8 - ENEMY_SPAWN_INTERVAL: int = 80 - ENEMY_DESPAWN_DISTANCE: int = 300 + ENEMY_SPAWN_INTERVAL_BASE: int = 30 # Base spawn interval + ENEMY_SPAWN_INTERVAL_MAX: int = 60 # Max spawn interval when many enemies exist + ENEMY_MIN_VISIBLE_COUNT: int = 2 # Minimum enemies to keep on screen + ENEMY_VISIBLE_DISTANCE: int = 120 # Distance within which enemies are considered "visible" + ENEMY_DESPAWN_DISTANCE: int = 250 ENEMY_SPEED_MIN: int = 3 ENEMY_SPEED_MAX: int = 5 ENEMY_DIRECTION_SWITCH_PROB: float = 0.0001 - ENEMY_OFFSCREEN_SPAWN_OFFSET: float = 100.0 - ENEMY_MIN_SPAWN_GAP: float = 30.0 - ENEMY_MAX_AGE: int = 1900 + ENEMY_SPAWN_OFFSET_MIN: float = 70.0 # Closer spawn distance + ENEMY_SPAWN_OFFSET_MAX: float = 130.0 # Max spawn offset + ENEMY_MIN_SPAWN_GAP: float = 25.0 # Reduced gap between spawns + ENEMY_MAX_AGE: int = 1900 INITIAL_ENEMY_COUNT: int = 4 - INITIAL_ENEMY_BASE_OFFSET: float = 40.0 - INITIAL_ENEMY_GAP: float = 30.0 + INITIAL_ENEMY_BASE_OFFSET: float = 35.0 # Closer initial enemies + INITIAL_ENEMY_GAP: float = 25.0 # Tighter initial spacing ENEMY_TYPE_CAMERO: int = 0 ENEMY_TYPE_FLAG_CARRIER: int = 1 ENEMY_TYPE_PICKUP: int = 2 @@ -82,7 +86,7 @@ class UpNDownConstants(NamedTuple): LIFE_BOTTOM_X_POSITIONS: chex.Array = jnp.array([13, 18, 25, 33, 33]) # X positions for 5 life cars LIFE_BOTTOM_Y: int = 195 # Collectible constants - unified dynamic spawning - MAX_COLLECTIBLES: int = 2 # Maximum collectibles that can exist at once (pool of mixed types) + MAX_COLLECTIBLES: int = 1 # Maximum collectibles that can exist at once (pool of mixed types) COLLECTIBLE_SPAWN_INTERVAL: int = 200 # Steps between spawn attempts COLLECTIBLE_DESPAWN_DISTANCE: int = 500 # Distance beyond which collectibles despawn # Collectible types (indices for type field) @@ -1182,7 +1186,7 @@ def init_direction(seg, road): @partial(jax.jit, static_argnums=(0,)) def _enemy_step_main(self, state: UpNDownState) -> UpNDownState: - """Spawn and move enemy cars that share the player's road logic.""" + """Spawn and move enemy cars with adaptive spawning for consistent enemy presence.""" base_key = jax.random.PRNGKey(2025) step_key = jax.random.fold_in(base_key, state.step_counter) key_spawn_offset, key_spawn_side, key_spawn_speed, key_spawn_direction, key_spawn_type, key_spawn_sign, key_flip_root = jax.random.split(step_key, 7) @@ -1191,29 +1195,59 @@ def _enemy_step_main(self, state: UpNDownState) -> UpNDownState: active_count = jnp.sum(active_mask.astype(jnp.int32)) can_spawn = active_count < self.consts.MAX_ENEMY_CARS - # Use jnp.where for branchless execution + # Calculate how many enemies are "visible" (within visible distance of player) + player_y = state.player_car.position.y + enemy_distances = jnp.abs(state.enemy_cars.position.y - player_y) + wrapped_distances = jnp.minimum(enemy_distances, self.consts.TRACK_LENGTH - enemy_distances) + visible_mask = jnp.logical_and(active_mask, wrapped_distances < self.consts.ENEMY_VISIBLE_DISTANCE) + visible_count = jnp.sum(visible_mask.astype(jnp.int32)) + + # Adaptive spawn interval: spawn faster when fewer visible enemies + # If below minimum, spawn immediately (interval = 0) + # Otherwise scale between BASE and MAX based on visible count + needs_urgent_spawn = visible_count < self.consts.ENEMY_MIN_VISIBLE_COUNT + spawn_interval = jnp.where( + needs_urgent_spawn, + jnp.int32(0), # Spawn immediately when too few visible + jnp.int32(self.consts.ENEMY_SPAWN_INTERVAL_BASE + + (visible_count * (self.consts.ENEMY_SPAWN_INTERVAL_MAX - self.consts.ENEMY_SPAWN_INTERVAL_BASE)) // + self.consts.MAX_ENEMY_CARS) + ) + + # Spawn when timer expires OR when we urgently need more enemies + timer_expired = state.enemy_spawn_timer <= 0 + should_spawn = jnp.logical_and( + jnp.logical_or(timer_expired, needs_urgent_spawn), + can_spawn + ) + + # Reset timer with adaptive interval spawn_timer = jnp.where( - jnp.logical_and(state.enemy_spawn_timer <= 0, can_spawn), - self.consts.ENEMY_SPAWN_INTERVAL, - state.enemy_spawn_timer - 1, + should_spawn, + spawn_interval, + jnp.maximum(state.enemy_spawn_timer - 1, 0), ) - should_spawn = jnp.logical_and(state.enemy_spawn_timer <= 0, can_spawn) inactive_mask = jnp.logical_not(active_mask) first_inactive = jnp.argmax(inactive_mask.astype(jnp.int32)) has_inactive = jnp.any(inactive_mask) - # Use jnp.where for branchless execution spawn_idx = jnp.where(has_inactive, first_inactive, jnp.array(0, dtype=jnp.int32)) spawn_mask = (jnp.arange(self.consts.MAX_ENEMY_CARS) == spawn_idx) & should_spawn & has_inactive - spawn_offset = self.consts.ENEMY_OFFSCREEN_SPAWN_OFFSET + active_count * self.consts.ENEMY_MIN_SPAWN_GAP + jax.random.uniform(key_spawn_offset, minval=0.0, maxval=40.0) + # Spawn closer when urgent (fewer visible enemies), farther when plenty exist + base_offset = jnp.where( + needs_urgent_spawn, + self.consts.ENEMY_SPAWN_OFFSET_MIN, # Spawn closer when needed + self.consts.ENEMY_SPAWN_OFFSET_MIN + visible_count * 10.0 # Farther when plenty exist + ) + spawn_offset = base_offset + jax.random.uniform(key_spawn_offset, minval=0.0, maxval=30.0) + spawn_side = jax.random.choice(key_spawn_side, jnp.array([-1.0, 1.0])) raw_spawn_y = state.player_car.position.y + spawn_side * spawn_offset spawn_y = -(((raw_spawn_y) * -1) % self.consts.TRACK_LENGTH) spawn_road = jax.random.randint(key_spawn_direction, shape=(), minval=0, maxval=2) segment_spawn = self._get_road_segment(spawn_y) - # Use jnp.where for branchless execution spawn_x = jnp.where( spawn_road == 0, self._get_x_on_road(spawn_y, segment_spawn, self.consts.FIRST_TRACK_CORNERS_X), @@ -1225,7 +1259,6 @@ def _enemy_step_main(self, state: UpNDownState) -> UpNDownState: spawn_speed = spawn_speed_mag * spawn_speed_sign spawn_type = jax.random.randint(key_spawn_type, shape=(), minval=0, maxval=4) - # Use jnp.where for branchless execution direction_raw = jnp.where( spawn_road == 0, self.consts.FIRST_TRACK_CORNERS_X[segment_spawn+1] - self.consts.FIRST_TRACK_CORNERS_X[segment_spawn], @@ -1369,7 +1402,7 @@ def _respawn_after_collision(self, state: UpNDownState, new_lives: chex.Array) - collectibles=collectibles, collectible_spawn_timer=jnp.array(0, dtype=jnp.int32), enemy_cars=enemy_cars, - enemy_spawn_timer=jnp.array(self.consts.ENEMY_SPAWN_INTERVAL, dtype=jnp.int32), + enemy_spawn_timer=jnp.array(self.consts.ENEMY_SPAWN_INTERVAL_BASE, dtype=jnp.int32), ) @partial(jax.jit, static_argnums=(0,)) @@ -1540,7 +1573,7 @@ def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: collectibles=collectibles, collectible_spawn_timer=jnp.array(0, dtype=jnp.int32), enemy_cars=enemy_cars, - enemy_spawn_timer=jnp.array(self.consts.ENEMY_SPAWN_INTERVAL, dtype=jnp.int32), + enemy_spawn_timer=jnp.array(self.consts.ENEMY_SPAWN_INTERVAL_BASE, dtype=jnp.int32), ) initial_obs = self._get_observation(state) return initial_obs, state From 035d9a91203494065b1e1a32d45c6fb51fc36d72 Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sun, 21 Dec 2025 16:06:09 +0100 Subject: [PATCH 63/76] improve jumping on enemys --- src/jaxatari/games/jax_upndown.py | 34 ++++++++++++++----------------- 1 file changed, 15 insertions(+), 19 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index b94c73a8a..316011696 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -45,9 +45,9 @@ class UpNDownConstants(NamedTuple): ENEMY_TYPE_TRUCK: int = 3 JUMP_FRAMES: int = 28 POST_JUMP_DELAY: int = 10 - LANDING_TOLERANCE: int = 15 # Pixels tolerance for landing on a road (increased by 5 for off-road landings) + LANDING_TOLERANCE: int = 20 # Pixels tolerance for landing on a road (increased by 5 for wider landing zone) LATE_JUMP_COLLISION_FRAMES: int = 2 - LANDING_COLLISION_DISTANCE: float = 8.0 # Larger collision distance when landing (for crossings) + LANDING_COLLISION_DISTANCE: float = 12.0 # Larger collision distance when landing (increased for easier enemy kills) GROUND_COLLISION_DISTANCE: float = 3.0 # Tight collision distance for ground collisions LATE_JUMP_ENEMY_SCORE: int = 400 STEEP_ROAD_SPEED_REDUCTION_INTERVAL: int = 12 # Frames between each speed reduction on steep roads @@ -1425,34 +1425,30 @@ def _enemy_collision_step_main(self, state: UpNDownState) -> UpNDownState: # For ground collision: only trigger when enemy position is within tight distance overlap_x_ground = dx <= self.consts.GROUND_COLLISION_DISTANCE overlap_y_ground = wrapped_dy <= self.consts.GROUND_COLLISION_DISTANCE - # For landing collision: use larger distance and road-independent (for crossings) - overlap_x_landing = dx <= self.consts.LANDING_COLLISION_DISTANCE - overlap_y_landing = wrapped_dy <= self.consts.LANDING_COLLISION_DISTANCE - # For late jump collision: use original larger overlap based on car dimensions + # For late jump collision: use larger overlap based on car dimensions overlap_x_jump = dx <= (state.player_car.position.width + state.enemy_cars.position.width) / 2.0 overlap_y_jump = wrapped_dy <= (state.player_car.position.height + state.enemy_cars.position.height) / 2.0 same_road = state.enemy_cars.current_road == state.player_car.current_road # Ground collision mask uses tight 3-pixel distance and same road ground_collision_mask = jnp.logical_and(state.enemy_cars.active, jnp.logical_and(same_road, jnp.logical_and(overlap_x_ground, overlap_y_ground))) - # Landing collision mask uses larger distance and is road-independent (for crossings) - landing_collision_mask = jnp.logical_and(state.enemy_cars.active, jnp.logical_and(overlap_x_landing, overlap_y_landing)) - # Jump collision mask uses original larger overlap (for scoring when jumping on enemies) - jump_collision_mask = jnp.logical_and(state.enemy_cars.active, jnp.logical_and(same_road, jnp.logical_and(overlap_x_jump, overlap_y_jump))) + # Jump collision mask is road-independent - can destroy enemies on either road when jumping + jump_collision_mask = jnp.logical_and(state.enemy_cars.active, jnp.logical_and(overlap_x_jump, overlap_y_jump)) collision_mask = jump_collision_mask # For late jump scoring any_jump_collision = jnp.any(jump_collision_mask) any_ground_collision = jnp.any(ground_collision_mask) - any_landing_collision = jnp.any(landing_collision_mask) - # Check if player is in the landing phase (just landed from a jump) - is_landing_phase = jnp.logical_and(state.post_jump_cooldown > 0, state.post_jump_cooldown <= self.consts.POST_JUMP_DELAY) + # Check if player is in post-landing invincibility phase + is_invincible = state.post_jump_cooldown > 0 late_jump_window = jnp.logical_and(state.is_jumping, state.jump_cooldown <= self.consts.LATE_JUMP_COLLISION_FRAMES) late_jump_collision = jnp.logical_and(any_jump_collision, late_jump_window) - grounded_collision = jnp.logical_and(any_ground_collision, jnp.logical_not(state.is_jumping)) - # Landing collision is road-independent and uses larger distance - landing_collision = jnp.logical_and(any_landing_collision, is_landing_phase) + # Ground collision only applies when not jumping AND not in post-landing invincibility + grounded_collision = jnp.logical_and( + any_ground_collision, + jnp.logical_and(jnp.logical_not(state.is_jumping), jnp.logical_not(is_invincible)) + ) def handle_late_jump(): hits = collision_mask.astype(jnp.int32) @@ -1476,8 +1472,8 @@ def handle_late_jump(): def handle_ground_collision(): return self._respawn_after_collision(state, state.lives - 1) - # Check for any collision that should cause respawn (ground or landing) - any_fatal_collision = jnp.logical_or(grounded_collision, landing_collision) + # Ground collision causes respawn (landing is now protected by invincibility) + any_fatal_collision = grounded_collision return jax.lax.cond( late_jump_collision, @@ -1946,7 +1942,7 @@ def render_enemy(carry, enemy_idx): player_screen_y = jnp.int32(105 - jump_offset) player_mask = self.SHAPE_MASKS["player"] - raster_player = self.jr.render_at(raster_enemies, state.player_car.position.x, player_screen_y, player_mask) + raster_player = self.jr.render_at_clipped(raster_enemies, state.player_car.position.x, player_screen_y, player_mask) wall_top_mask = self.SHAPE_MASKS["wall_top"] raster_wall_top = self.jr.render_at(raster_player, 0, 0, wall_top_mask) From 523b1d72b1ab7cd7cc31c6faeddf72ea95491413 Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sun, 21 Dec 2025 16:22:10 +0100 Subject: [PATCH 64/76] reuse movement logic for enemys --- src/jaxatari/games/jax_upndown.py | 331 +++++++++++++----------------- 1 file changed, 144 insertions(+), 187 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 316011696..f1b717075 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -295,19 +295,21 @@ def _apply_steep_road_penalty( @partial(jax.jit, static_argnums=(0,)) def _get_slope_and_intercept_from_indices(self, current_road: chex.Array, road_index_A: chex.Array, road_index_B: chex.Array) -> Tuple[chex.Array, chex.Array]: - trackx, tracky, road_index = jax.lax.cond( - current_road == 0, - lambda _: (self.consts.FIRST_TRACK_CORNERS_X, self.consts.TRACK_CORNERS_Y, road_index_A), - lambda _: (self.consts.SECOND_TRACK_CORNERS_X, self.consts.TRACK_CORNERS_Y, road_index_B), - operand=None, - ) - slope = jax.lax.cond( - trackx[road_index+1] - trackx[road_index] != 0, - lambda _: (tracky[road_index+1] - tracky[road_index]) / (trackx[road_index+1] - trackx[road_index]), - lambda _: 300.0, - operand=None, - ) - b = tracky[road_index] - slope * trackx[road_index] + """Calculate slope and intercept for the current road segment.""" + road_index = jnp.where(current_road == 0, road_index_A, road_index_B) + x1 = jnp.where(current_road == 0, + self.consts.FIRST_TRACK_CORNERS_X[road_index], + self.consts.SECOND_TRACK_CORNERS_X[road_index]) + x2 = jnp.where(current_road == 0, + self.consts.FIRST_TRACK_CORNERS_X[road_index + 1], + self.consts.SECOND_TRACK_CORNERS_X[road_index + 1]) + y1 = self.consts.TRACK_CORNERS_Y[road_index] + y2 = self.consts.TRACK_CORNERS_Y[road_index + 1] + + dx = x2 - x1 + dy = y2 - y1 + slope = jnp.where(dx != 0, dy / dx, 300.0) + b = y1 - slope * x1 return slope, b @partial(jax.jit, static_argnums=(0,)) @@ -332,6 +334,24 @@ def _get_x_on_road(self, y: chex.Array, road_segment: chex.Array, track_corners_ t = jnp.where(y2 != y1, (y - y1) / (y2 - y1), 0.0) return x1 + t * (x2 - x1) + @partial(jax.jit, static_argnums=(0,)) + def _get_x_for_road_index(self, y: chex.Array, road_segment: chex.Array, road_index: chex.Array) -> chex.Array: + """Get X position on road A (index 0) or road B (index 1) for given Y and segment.""" + track_corners = jnp.where( + road_index == 0, + self.consts.FIRST_TRACK_CORNERS_X[road_segment], + self.consts.SECOND_TRACK_CORNERS_X[road_segment], + ) + track_corners_next = jnp.where( + road_index == 0, + self.consts.FIRST_TRACK_CORNERS_X[road_segment + 1], + self.consts.SECOND_TRACK_CORNERS_X[road_segment + 1], + ) + y1 = self.consts.TRACK_CORNERS_Y[road_segment] + y2 = self.consts.TRACK_CORNERS_Y[road_segment + 1] + t = jnp.where(y2 != y1, (y - y1) / (y2 - y1), 0.0) + return track_corners + t * (track_corners_next - track_corners) + @partial(jax.jit, static_argnums=(0,)) def _get_road_segment(self, y: chex.Array) -> chex.Array: """Return the road segment index for a given y position.""" @@ -339,6 +359,56 @@ def _get_road_segment(self, y: chex.Array) -> chex.Array: max_idx = jnp.int32(len(self.consts.TRACK_CORNERS_Y) - 1) return jnp.clip(segments - 1, 0, max_idx) + @partial(jax.jit, static_argnums=(0,)) + def _compute_direction_x(self, current_road: chex.Array, road_index_A: chex.Array, road_index_B: chex.Array) -> chex.Array: + """Calculate the X direction for movement on the current road segment. + + Returns: + Direction as int32: -1 for left, 1 for right (defaults to -1 for vertical segments) + """ + # Select the road index based on which road we're on + road_index = jnp.where(current_road == 0, road_index_A, road_index_B) + # Select corners for the current road + x_curr = jnp.where(current_road == 0, + self.consts.FIRST_TRACK_CORNERS_X[road_index], + self.consts.SECOND_TRACK_CORNERS_X[road_index]) + x_next = jnp.where(current_road == 0, + self.consts.FIRST_TRACK_CORNERS_X[road_index + 1], + self.consts.SECOND_TRACK_CORNERS_X[road_index + 1]) + direction_raw = x_next - x_curr + return jnp.where(direction_raw == 0, -1, jnp.sign(direction_raw)).astype(jnp.int32) + + @partial(jax.jit, static_argnums=(0,)) + def _move_on_road( + self, + position: EntityPosition, + slope: chex.Array, + b: chex.Array, + speed_sign: chex.Array, + step_size: chex.Array, + car_direction_x: chex.Array, + move_y: chex.Array, + move_x: chex.Array, + ) -> Tuple[chex.Array, chex.Array]: + """Move a car on the road based on timing and geometry. + + Returns: + Tuple of (new_x, new_y) positions + """ + new_y = jnp.where( + jnp.logical_and(move_y, self._is_on_line_for_position(position, slope, b, speed_sign, 1)), + position.y + speed_sign * -step_size, + position.y, + ) + + new_x = jnp.where( + jnp.logical_and(move_x, self._is_on_line_for_position(position, slope, b, speed_sign, 2)), + position.x + speed_sign * car_direction_x * step_size, + position.x, + ) + + return new_x, new_y + @partial(jax.jit, static_argnums=(0,)) def _is_steep_road_segment(self, current_road: chex.Array, road_index_A: chex.Array, road_index_B: chex.Array) -> chex.Array: """Check if the current road segment is steep (no X direction change). @@ -349,12 +419,14 @@ def _is_steep_road_segment(self, current_road: chex.Array, road_index_A: chex.Ar Returns True if the segment is steep (requires jump to pass when going up). """ # Get the X difference for the current road segment - x_diff = jax.lax.cond( - current_road == 0, - lambda _: jnp.abs(self.consts.FIRST_TRACK_CORNERS_X[road_index_A + 1] - self.consts.FIRST_TRACK_CORNERS_X[road_index_A]), - lambda _: jnp.abs(self.consts.SECOND_TRACK_CORNERS_X[road_index_B + 1] - self.consts.SECOND_TRACK_CORNERS_X[road_index_B]), - operand=None, - ) + road_index = jnp.where(current_road == 0, road_index_A, road_index_B) + x_curr = jnp.where(current_road == 0, + self.consts.FIRST_TRACK_CORNERS_X[road_index], + self.consts.SECOND_TRACK_CORNERS_X[road_index]) + x_next = jnp.where(current_road == 0, + self.consts.FIRST_TRACK_CORNERS_X[road_index + 1], + self.consts.SECOND_TRACK_CORNERS_X[road_index + 1]) + x_diff = jnp.abs(x_next - x_curr) # A segment is steep if there's no X change (or very small change) return x_diff < 1.0 @@ -433,57 +505,26 @@ def _advance_player_car( slope, b = self._get_slope_and_intercept_from_indices(current_road, road_index_A, road_index_B) # Determine X direction based on current road segment (for normal movement) - direction_raw = jax.lax.cond( - current_road == 0, - lambda _: self.consts.FIRST_TRACK_CORNERS_X[road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_X[road_index_A], - lambda _: self.consts.SECOND_TRACK_CORNERS_X[road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[road_index_B], - operand=None, - ) - # Use sign, default to -1 for zero (vertical segments) - car_direction_x = jnp.where(direction_raw == 0, -1, jnp.sign(direction_raw)).astype(jnp.int32) + car_direction_x = self._compute_direction_x(current_road, road_index_A, road_index_B) position = EntityPosition(x=position_x, y=position_y, width=width, height=height) + # === CALCULATE ROAD-BASED MOVEMENT (used when not jumping) === + road_x, road_y = self._move_on_road( + position, slope, b, speed_sign, step_size, car_direction_x, move_y, move_x + ) + # === Y MOVEMENT === # When jumping: move freely in Y direction - # When on road: only move if allowed by road geometry - new_player_y = jax.lax.cond( - move_y, - lambda _: jax.lax.cond( - is_jumping, - lambda _: position_y + speed_sign * -step_size, # Free movement while jumping - lambda _: jax.lax.cond( - self._is_on_line_for_position(position, slope, b, speed_sign, 1), - lambda _: position_y + speed_sign * -step_size, - lambda _: jnp.array(position_y, float), - operand=None, - ), - operand=None, - ), - lambda _: jnp.array(position_y, float), - operand=None, - ) + # When on road: use road-based movement result + jump_y = jnp.where(move_y, position_y + speed_sign * -step_size, position_y) + new_player_y = jnp.where(is_jumping, jump_y, road_y) # === X MOVEMENT === # When jumping: use stored_jump_slope (locked at jump start) - moves X proportionally to Y - # The slope already encodes direction (dx/dy), so multiply by Y step size and speed_sign - # When on road: only move if allowed by road geometry - new_player_x = jax.lax.cond( - move_x, - lambda _: jax.lax.cond( - is_jumping, - lambda _: position_x - speed_sign * stored_jump_slope * step_size, # Slope-based movement (negated because Y decreases going forward) - lambda _: jax.lax.cond( - self._is_on_line_for_position(position, slope, b, speed_sign, 2), - lambda _: position_x + speed_sign * car_direction_x * step_size, # Normal road movement - lambda _: jnp.array(position_x, float), - operand=None, - ), - operand=None, - ), - lambda _: jnp.array(position_x, float), - operand=None, - ) + # When on road: use road-based movement result + jump_x = jnp.where(move_x, position_x - speed_sign * stored_jump_slope * step_size, position_x) + new_player_x = jnp.where(is_jumping, jump_x, road_x) # === LANDING LOGIC === # Get the current road segment based on new Y position @@ -524,60 +565,21 @@ def _advance_player_car( landing_in_water = jnp.logical_and(is_landing, jnp.logical_not(valid_landing)) # === UPDATE ROAD STATE === - # Determine which road to assign on landing - landed_road = jax.lax.cond( - on_road_A, - lambda _: jnp.int32(0), - lambda _: jax.lax.cond( - on_road_B, - lambda _: jnp.int32(1), - lambda _: nearest_road_id, # Between roads - use nearest - operand=None, - ), - operand=None, - ) + # Determine which road to assign on landing (priority: road A > road B > nearest) + landed_road = jnp.where(on_road_A, jnp.int32(0), jnp.where(on_road_B, jnp.int32(1), nearest_road_id)) - # Update current_road - # - If landing in water: set to 2 (water/crash marker) - # - If landing successfully: set to the landed road - # - If still jumping: keep current road (frozen during jump) - # - If on road normally: update based on position - updated_current_road = jax.lax.cond( - landing_in_water, - lambda _: jnp.int32(2), # Water crash - lambda _: jax.lax.cond( - is_landing, - lambda _: landed_road, # Successfully landed - lambda _: jax.lax.cond( - is_jumping, - lambda _: current_road, # Keep road frozen while jumping - lambda _: jax.lax.cond( - current_road == 2, - lambda _: nearest_road_id, # Recover from water state - lambda _: current_road, # Normal on-road movement - operand=None, - ), - operand=None, - ), - operand=None, - ), - operand=None, - ) + # Update current_road using nested jnp.where for vectorized execution + # Priority: water crash > landing > jumping (frozen) > recover from water > normal + normal_road = jnp.where(current_road == 2, nearest_road_id, current_road) + jumping_road = jnp.where(is_jumping, current_road, normal_road) + landing_road = jnp.where(is_landing, landed_road, jumping_road) + updated_current_road = jnp.where(landing_in_water, jnp.int32(2), landing_road) # Update road indices to match current segment when not jumping - next_road_index_A = jax.lax.cond( - jnp.logical_and(jnp.logical_not(is_jumping), updated_current_road == 0), - lambda _: segment, - lambda _: road_index_A, - operand=None, - ) - - next_road_index_B = jax.lax.cond( - jnp.logical_and(jnp.logical_not(is_jumping), updated_current_road == 1), - lambda _: segment, - lambda _: road_index_B, - operand=None, - ) + not_jumping_on_road_A = jnp.logical_and(jnp.logical_not(is_jumping), updated_current_road == 0) + not_jumping_on_road_B = jnp.logical_and(jnp.logical_not(is_jumping), updated_current_road == 1) + next_road_index_A = jnp.where(not_jumping_on_road_A, segment, road_index_A) + next_road_index_B = jnp.where(not_jumping_on_road_B, segment, road_index_B) # Wrap Y position for looping track wrapped_y = -((new_player_y * -1) % self.consts.TRACK_LENGTH) @@ -614,42 +616,14 @@ def _advance_car_core( """Simplified car advancement for enemy cars (no jumping/landing logic).""" # Calculate movement timing using helper move_y, move_x, step_size, speed_sign = self._compute_movement_timing(speed, step_counter) - slope, b = self._get_slope_and_intercept_from_indices(current_road, road_index_A, road_index_B) - - direction_raw = jax.lax.cond( - current_road == 0, - lambda _: self.consts.FIRST_TRACK_CORNERS_X[road_index_A+1] - self.consts.FIRST_TRACK_CORNERS_X[road_index_A], - lambda _: self.consts.SECOND_TRACK_CORNERS_X[road_index_B+1] - self.consts.SECOND_TRACK_CORNERS_X[road_index_B], - operand=None, - ) - # Use sign, default to -1 for zero (vertical segments) - car_direction_x = jnp.where(direction_raw == 0, -1, jnp.sign(direction_raw)).astype(jnp.int32) - + car_direction_x = self._compute_direction_x(current_road, road_index_A, road_index_B) + position = EntityPosition(x=position_x, y=position_y, width=width, height=height) - - new_y = jax.lax.cond( - move_y, - lambda _: jax.lax.cond( - self._is_on_line_for_position(position, slope, b, speed_sign, 1), - lambda _: position_y + speed_sign * -step_size, - lambda _: jnp.array(position_y, float), - operand=None, - ), - lambda _: jnp.array(position_y, float), - operand=None, - ) - - new_x = jax.lax.cond( - move_x, - lambda _: jax.lax.cond( - self._is_on_line_for_position(position, slope, b, speed_sign, 2), - lambda _: position_x + speed_sign * car_direction_x * step_size, - lambda _: jnp.array(position_x, float), - operand=None, - ), - lambda _: jnp.array(position_x, float), - operand=None, + + # Use shared movement helper + new_x, new_y = self._move_on_road( + position, slope, b, speed_sign, step_size, car_direction_x, move_y, move_x ) wrapped_y = -((new_y * -1) % self.consts.TRACK_LENGTH) @@ -707,10 +681,7 @@ def check_flag_collision(flag_idx): # Check if player is close enough to collect the flag y_distance = jnp.abs(new_player_y - flag_y) x_distance = jnp.abs(player_x - flag_x) - same_road = jnp.logical_or( - jnp.logical_and(current_road == 0, flag_road == 0), - jnp.logical_and(current_road == 1, flag_road == 1), - ) + same_road = (current_road == flag_road) collision = jnp.logical_and( jnp.logical_and(y_distance < self.consts.COLLISION_THRESHOLD, x_distance < self.consts.COLLISION_THRESHOLD), @@ -825,10 +796,7 @@ def check_collision(idx): y_distance = jnp.abs(new_player_y - c_y) x_distance = jnp.abs(player_x - c_x) - same_road = jnp.logical_or( - jnp.logical_and(current_road == 0, c_road == 0), - jnp.logical_and(current_road == 1, c_road == 1), - ) + same_road = (current_road == c_road) collision = jnp.logical_and( jnp.logical_and(y_distance < self.consts.COLLISION_THRESHOLD, x_distance < self.consts.COLLISION_THRESHOLD), @@ -1709,27 +1677,22 @@ def _createBackgroundSprite(self, dimensions: Tuple[int, int]) -> jnp.ndarray: self._num_road_tiles = int(self._road_tile_offsets.shape[0]) self.enemy_sprite_names = { - self.consts.ENEMY_TYPE_CAMERO: ("camero_left", "camero_right"), - self.consts.ENEMY_TYPE_FLAG_CARRIER: ("flag_carrier_left", "flag_carrier_right"), - self.consts.ENEMY_TYPE_PICKUP: ("pick_up_truck_left", "pick_up_truck_right"), - self.consts.ENEMY_TYPE_TRUCK: ("truck_left", "truck_right"), + self.consts.ENEMY_TYPE_CAMERO: "camero_left", + self.consts.ENEMY_TYPE_FLAG_CARRIER: "flag_carrier_left", + self.consts.ENEMY_TYPE_PICKUP: "pick_up_truck_left", + self.consts.ENEMY_TYPE_TRUCK: "truck_left", } # Pre-pad enemy masks to a common shape so switch/array indexing works under jit + # Only use left sprites - right sprites are created by flipping horizontally enemy_left_raw = [ self.SHAPE_MASKS["camero_left"], self.SHAPE_MASKS["flag_carrier_left"], self.SHAPE_MASKS["pick_up_truck_left"], self.SHAPE_MASKS["truck_left"], ] - enemy_right_raw = [ - self.SHAPE_MASKS["camero_right"], - self.SHAPE_MASKS["flag_carrier_right"], - self.SHAPE_MASKS["pick_up_truck_right"], - self.SHAPE_MASKS["truck_right"], - ] - max_h = max([m.shape[0] for m in enemy_left_raw + enemy_right_raw]) - max_w = max([m.shape[1] for m in enemy_left_raw + enemy_right_raw]) + max_h = max([m.shape[0] for m in enemy_left_raw]) + max_w = max([m.shape[1] for m in enemy_left_raw]) def _pad_mask(mask): pad_h = max_h - mask.shape[0] @@ -1737,7 +1700,8 @@ def _pad_mask(mask): return jnp.pad(mask, ((0, pad_h), (0, pad_w)), constant_values=self.jr.TRANSPARENT_ID) self.enemy_left_masks = jnp.stack([_pad_mask(m) for m in enemy_left_raw], axis=0) - self.enemy_right_masks = jnp.stack([_pad_mask(m) for m in enemy_right_raw], axis=0) + # Create right-facing masks by horizontally flipping the left masks + self.enemy_right_masks = jnp.flip(self.enemy_left_masks, axis=2) # Precompute flag mask data for recoloring without special-casing pink self.flag_base_mask = self.SHAPE_MASKS["pink_flag"] @@ -1786,6 +1750,15 @@ def _get_road_sprite_sizes(self, road_files: list[str]) -> list: complete_size = int(sum(sizes)) return sizes, complete_size + def _get_x_on_road(self, y: chex.Array, road_segment: chex.Array, track_corners_x: chex.Array) -> chex.Array: + """Calculate the X position on a road given a Y coordinate and road segment.""" + y1 = self.consts.TRACK_CORNERS_Y[road_segment] + y2 = self.consts.TRACK_CORNERS_Y[road_segment + 1] + x1 = track_corners_x[road_segment] + x2 = track_corners_x[road_segment + 1] + t = jnp.where(y2 != y1, (y - y1) / (y2 - y1), 0.0) + return x1 + t * (x2 - x1) + def _find_palette_id(self, rgba: jnp.ndarray) -> int: """Return palette index for an RGBA color, falling back to first entry if missing.""" color_rgb = rgba[:3] @@ -1808,20 +1781,6 @@ def _jump_arc_offset(self, jump_cooldown: chex.Array) -> chex.Array: centered = (progress - 0.5) * 2.0 return self.consts.JUMP_ARC_HEIGHT * (1.0 - centered * centered) - def _get_x_on_road(self, y: chex.Array, road_segment: chex.Array, track_corners_x: chex.Array) -> chex.Array: - """Linear interpolation of x along the given road segment for y.""" - y1 = self.consts.TRACK_CORNERS_Y[road_segment] - y2 = self.consts.TRACK_CORNERS_Y[road_segment + 1] - x1 = track_corners_x[road_segment] - x2 = track_corners_x[road_segment + 1] - t = jax.lax.cond( - y2 != y1, - lambda _: (y - y1) / (y2 - y1), - lambda _: 0.0, - operand=None, - ) - return x1 + t * (x2 - x1) - def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.ndarray, bottomBlockSprite: jnp.ndarray, tempPointer: jnp.ndarray, blackoutSquare: jnp.ndarray) -> tuple[list, list[str]]: """Returns the asset manifest and ordered road files.""" road_dir = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/roads" @@ -1834,14 +1793,11 @@ def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.n {'name': 'background', 'type': 'background', 'data': backgroundSprite}, {'name': 'road', 'type': 'group', 'files': roads}, {'name': 'player', 'type': 'single', 'file': 'player_car.npy'}, + # Only load left-facing enemy sprites; right-facing are created by flipping {'name': 'camero_left', 'type': 'single', 'file': 'enemy_cars/camero_left.npy'}, - {'name': 'camero_right', 'type': 'single', 'file': 'enemy_cars/camero_right.npy'}, {'name': 'flag_carrier_left', 'type': 'single', 'file': 'enemy_cars/flag_carrier_left.npy'}, - {'name': 'flag_carrier_right', 'type': 'single', 'file': 'enemy_cars/flag_carrier_right.npy'}, {'name': 'pick_up_truck_left', 'type': 'single', 'file': 'enemy_cars/pick_up_truck_left.npy'}, - {'name': 'pick_up_truck_right', 'type': 'single', 'file': 'enemy_cars/pick_up_truck_right.npy'}, {'name': 'truck_left', 'type': 'single', 'file': 'enemy_cars/truck_left.npy'}, - {'name': 'truck_right', 'type': 'single', 'file': 'enemy_cars/truck_right.npy'}, {'name': 'wall_top', 'type': 'procedural', 'data': topBlockSprite}, {'name': 'wall_bottom', 'type': 'procedural', 'data': bottomBlockSprite}, {'name': 'all_flags_top', 'type': 'single', 'file': 'all_flags_top.npy'}, @@ -1908,9 +1864,10 @@ def combine(i, acc): raster = jax.lax.fori_loop(0, total_segments, combine, raster) def select_enemy_mask(enemy_type: chex.Array, going_left: chex.Array): + """Select enemy mask: left masks are base, right masks are horizontally flipped.""" left_mask = self.enemy_left_masks[enemy_type] right_mask = self.enemy_right_masks[enemy_type] - return jax.lax.cond(going_left, lambda _: left_mask, lambda _: right_mask, operand=None) + return jnp.where(going_left, left_mask, right_mask) def render_enemy(carry, enemy_idx): raster = carry From 5b84dab7b8d7d6afa97ca783bceb41094a62629c Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sun, 21 Dec 2025 16:32:00 +0100 Subject: [PATCH 65/76] reworked behaivor on step sections --- src/jaxatari/games/jax_upndown.py | 118 +++++++++++++++++++++--------- 1 file changed, 84 insertions(+), 34 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index f1b717075..005b8e0bf 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -430,6 +430,31 @@ def _is_steep_road_segment(self, current_road: chex.Array, road_index_A: chex.Ar # A segment is steep if there's no X change (or very small change) return x_diff < 1.0 + @partial(jax.jit, static_argnums=(0,)) + def _get_steep_segment_progress(self, position_y: chex.Array, current_road: chex.Array, + road_index_A: chex.Array, road_index_B: chex.Array) -> chex.Array: + """Calculate progress (0.0 to 1.0) through the current steep road segment. + + 0.0 = at the bottom (start) of the steep segment + 1.0 = at the top (end) of the steep segment + + Progress is measured in the direction of forward travel (upward = positive Y direction in game space, + but Y decreases as we go forward on the track). + """ + road_index = jnp.where(current_road == 0, road_index_A, road_index_B) + # Y coordinates of segment boundaries + y_start = self.consts.TRACK_CORNERS_Y[road_index] # Start of segment (lower Y = further ahead) + y_end = self.consts.TRACK_CORNERS_Y[road_index + 1] # End of segment (higher Y in absolute terms) + + # Calculate progress: how far through the segment are we? + # Since Y decreases as we go forward, we need to invert + segment_length = jnp.abs(y_end - y_start) + # Distance from segment start (in forward direction) + distance_from_start = jnp.abs(position_y - y_start) + + progress = jnp.where(segment_length > 0.001, distance_from_start / segment_length, 0.0) + return jnp.clip(progress, 0.0, 1.0) + @partial(jax.jit, static_argnums=(0,)) def _check_landing_position( self, @@ -871,65 +896,90 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: up = jnp.logical_or(action == Action.UP, action == Action.UPFIRE) down = jnp.logical_or(action == Action.DOWN, action == Action.DOWNFIRE) jump = jnp.logical_or(action == Action.FIRE, jnp.logical_or(action == Action.UPFIRE, action == Action.DOWNFIRE)) - player_speed = state.player_car.speed.astype(jnp.float32) - - # Use jnp.where for branchless execution + + # Check if on a steep road section FIRST (before applying speed changes) + is_on_steep_road = self._is_steep_road_segment( + state.player_car.current_road, + state.player_car.road_index_A, + state.player_car.road_index_B, + ) + + # Calculate progress through steep segment (0.0 = bottom, 1.0 = top) + steep_progress = self._get_steep_segment_progress( + state.player_car.position.y, + state.player_car.current_road, + state.player_car.road_index_A, + state.player_car.road_index_B, + ) + + # Determine if player is on steep road going up (not jumping) + on_steep_not_jumping = jnp.logical_and(is_on_steep_road, jnp.logical_not(state.is_jumping)) + + # Start with current speed + player_speed = state.player_car.speed + + # === STEEP ROAD BLOCKING LOGIC === + # On steep road: UP action has NO effect (can't accelerate while on steep section) + # Apply UP acceleration only if NOT on steep road (or if jumping over it) + can_accelerate = jnp.logical_not(on_steep_not_jumping) player_speed = jnp.where( - jnp.logical_and(state.player_car.speed < self.consts.MAX_SPEED, up), + jnp.logical_and(jnp.logical_and(player_speed < self.consts.MAX_SPEED, up), can_accelerate), player_speed + 1, player_speed, ) - + + # DOWN action always works (can brake/reverse) player_speed = jnp.where( - jnp.logical_and(state.player_car.speed > -self.consts.MAX_SPEED, down), + jnp.logical_and(player_speed > -self.consts.MAX_SPEED, down), player_speed - 1, player_speed, ) - - # Check if on a steep road section (no X direction change) and apply speed reduction - # This simulates steep road sections that require a jump to pass when going upward - is_on_steep_road = self._is_steep_road_segment( - state.player_car.current_road, - state.player_car.road_index_A, - state.player_car.road_index_B, - ) - # Only apply steep road penalty when: - # 1. Player is on a steep road section - # 2. Player is not jumping - # 3. Player has positive speed (going upward) - player_speed, steep_road_timer, jump_boost_multiplier = self._apply_steep_road_penalty( - player_speed, is_on_steep_road, state.steep_road_timer, state.is_jumping, state.jump_cooldown - ) - on_steep_going_up = jnp.logical_and( - is_on_steep_road, - jnp.logical_and( - jnp.logical_not(state.is_jumping), - player_speed > 0 - ) - ) - # Update steep road timer - increment when on steep road going up, reset otherwise (use jnp.where) + + # === STEEP ROAD SPEED REDUCTION & SLIDE BACK === + # Only apply when on steep road, not jumping, and trying to go up (positive speed) + on_steep_going_up = jnp.logical_and(on_steep_not_jumping, player_speed > 0) + + # Update steep road timer - increment when on steep road going up steep_road_timer = jnp.where( on_steep_going_up, state.steep_road_timer + 1, jnp.array(0, dtype=jnp.int32), ) - # Only reduce speed when timer reaches the interval threshold + + # Check if player has reached halfway point (50% progress through segment) + past_halfway = steep_progress >= 0.5 + + # Two behaviors based on progress: + # 1. Before halfway: gradually reduce speed using timer + # 2. At/past halfway: immediately set speed to -2 (slide back) + + # Before halfway: reduce speed periodically using timer should_reduce_speed = jnp.logical_and( on_steep_going_up, - steep_road_timer >= self.consts.STEEP_ROAD_SPEED_REDUCTION_INTERVAL + jnp.logical_and( + jnp.logical_not(past_halfway), + steep_road_timer >= self.consts.STEEP_ROAD_SPEED_REDUCTION_INTERVAL + ) ) - # Gradually reduce speed toward -2 when on steep section without jumping (use jnp.where) player_speed = jnp.where( should_reduce_speed, - jnp.maximum(player_speed - 1, jnp.int32(-2)), + jnp.maximum(player_speed - 1, jnp.int32(0)), # Reduce but not below 0 yet player_speed, ) - # Reset timer after speed reduction (use jnp.where) + # Reset timer after speed reduction steep_road_timer = jnp.where( should_reduce_speed, jnp.array(0, dtype=jnp.int32), steep_road_timer, ) + + # At/past halfway: force speed to -2 (slide back down) + should_slide_back = jnp.logical_and(on_steep_going_up, past_halfway) + player_speed = jnp.where( + should_slide_back, + jnp.int32(-3), + player_speed, + ) can_start_jump = jnp.logical_and(state.jump_cooldown == 0, state.post_jump_cooldown == 0) is_jumping = jnp.logical_or( From 29936f9364a87b543f78f0ef6d46989da73b219e Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sun, 21 Dec 2025 17:13:18 +0100 Subject: [PATCH 66/76] move RNG key up and make functions jittable --- src/jaxatari/games/jax_upndown.py | 63 ++++++++++++++++--------------- 1 file changed, 32 insertions(+), 31 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 005b8e0bf..a99f23f4e 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -165,6 +165,7 @@ class UpNDownState(NamedTuple): is_dead: chex.Array respawn_timer: chex.Array step_counter: chex.Array + rng_key: chex.PRNGKey round_started: chex.Array movement_steps: chex.Array steep_road_timer: chex.Array # Timer for steep road speed reduction @@ -187,8 +188,6 @@ class UpNDownObservation(NamedTuple): class UpNDownInfo(NamedTuple): time: jnp.ndarray - - class JaxUpNDown(JaxEnvironment[UpNDownState, UpNDownObservation, UpNDownInfo, UpNDownConstants]): def __init__(self, consts: UpNDownConstants = None, reward_funcs: list[callable]=None): consts = consts or UpNDownConstants() @@ -734,9 +733,9 @@ def check_flag_collision(flag_idx): return new_flags, flag_score, new_flags_collected_mask @partial(jax.jit, static_argnums=(0,)) - def _collectible_step(self, state: UpNDownState, new_player_y: chex.Array, player_x: chex.Array, current_road: chex.Array) -> Tuple[Collectible, chex.Array, chex.Array]: + def _collectible_step(self, state: UpNDownState, new_player_y: chex.Array, player_x: chex.Array, current_road: chex.Array, rng_key: chex.PRNGKey) -> Tuple[Collectible, chex.Array, chex.Array, chex.PRNGKey]: """Update collectible spawning, despawning, and collection (unified for all types). - + Handles mixed-type collectibles (cherry, balloon, lollypop, ice cream) in a single pool. Type is randomized on spawn with probabilities defined in COLLECTIBLE_SPAWN_PROBABILITIES. @@ -745,10 +744,13 @@ def _collectible_step(self, state: UpNDownState, new_player_y: chex.Array, playe new_player_y: Updated player Y position after movement player_x: Current player X position current_road: Current road player is on + rng_key: PRNG key to drive spawn randomness Returns: - Tuple of (updated_collectibles, score_delta, new_spawn_timer) + Tuple of (updated_collectibles, score_delta, new_spawn_timer, new_rng_key) """ + rng_key, key1, key2, key3, key4 = jax.random.split(rng_key, 5) + # Collectible spawning logic - decrement timer and spawn when ready (use jnp.where for branchless) new_collectible_timer = jnp.where( state.collectible_spawn_timer <= 0, @@ -764,10 +766,6 @@ def _collectible_step(self, state: UpNDownState, new_player_y: chex.Array, playe has_inactive_slot = jnp.any(inactive_mask) spawn_idx = jnp.where(has_inactive_slot, first_inactive, jnp.array(0, dtype=jnp.int32)) - # Generate random spawn position using fold_in for deterministic randomness - base_key = jax.random.PRNGKey(0) - key_for_spawn = jax.random.fold_in(base_key, state.step_counter) - key1, key2, key3, key4 = jax.random.split(key_for_spawn, 4) y_spawn = jax.random.uniform(key1, minval=-900.0, maxval=-100.0) road_spawn = jnp.array(jax.random.randint(key2, shape=(), minval=0, maxval=2), dtype=jnp.int32) color_spawn = jnp.array(jax.random.randint(key3, shape=(), minval=0, maxval=len(self.consts.COLLECTIBLE_COLORS)), dtype=jnp.int32) @@ -848,7 +846,7 @@ def check_collision(idx): active=final_active, ) - return updated_collectibles, score_delta, new_collectible_timer + return updated_collectibles, score_delta, new_collectible_timer, rng_key @partial(jax.jit, static_argnums=(0,)) def _death_step(self, state: UpNDownState) -> UpNDownState: @@ -1125,14 +1123,15 @@ def _collectible_step_main(self, state: UpNDownState) -> UpNDownState: player_x = state.player_car.position.x current_road = state.player_car.current_road - updated_collectibles, collectible_score, new_collectible_timer = self._collectible_step( - state, new_player_y, player_x, current_road + updated_collectibles, collectible_score, new_collectible_timer, rng_key = self._collectible_step( + state, new_player_y, player_x, current_road, state.rng_key ) return state._replace( score=state.score + collectible_score, collectibles=updated_collectibles, collectible_spawn_timer=new_collectible_timer, + rng_key=rng_key, ) @partial(jax.jit, static_argnums=(0,)) @@ -1205,9 +1204,7 @@ def init_direction(seg, road): @partial(jax.jit, static_argnums=(0,)) def _enemy_step_main(self, state: UpNDownState) -> UpNDownState: """Spawn and move enemy cars with adaptive spawning for consistent enemy presence.""" - base_key = jax.random.PRNGKey(2025) - step_key = jax.random.fold_in(base_key, state.step_counter) - key_spawn_offset, key_spawn_side, key_spawn_speed, key_spawn_direction, key_spawn_type, key_spawn_sign, key_flip_root = jax.random.split(step_key, 7) + rng_key, key_spawn_offset, key_spawn_side, key_spawn_speed, key_spawn_direction, key_spawn_type, key_spawn_sign, key_flip_root = jax.random.split(state.rng_key, 8) active_mask = state.enemy_cars.active active_count = jnp.sum(active_mask.astype(jnp.int32)) @@ -1361,14 +1358,13 @@ def _enemy_step_main(self, state: UpNDownState) -> UpNDownState: return state._replace( enemy_cars=next_enemy_cars, enemy_spawn_timer=spawn_timer, + rng_key=rng_key, ) @partial(jax.jit, static_argnums=(0,)) def _respawn_after_collision(self, state: UpNDownState, new_lives: chex.Array) -> UpNDownState: """Respawn the player on a random road while preserving score and flags.""" - base_key = jax.random.PRNGKey(1337) - key_spawn = jax.random.fold_in(base_key, state.step_counter) - road_key, enemy_key = jax.random.split(key_spawn, 2) + rng_key, road_key, enemy_key = jax.random.split(state.rng_key, 3) player_start_y = jnp.array(0.0) start_segment = jnp.array(0, dtype=jnp.int32) @@ -1421,6 +1417,7 @@ def _respawn_after_collision(self, state: UpNDownState, new_lives: chex.Array) - collectible_spawn_timer=jnp.array(0, dtype=jnp.int32), enemy_cars=enemy_cars, enemy_spawn_timer=jnp.array(self.consts.ENEMY_SPAWN_INTERVAL_BASE, dtype=jnp.int32), + rng_key=rng_key, ) @partial(jax.jit, static_argnums=(0,)) @@ -1515,15 +1512,11 @@ def _passive_score_step_main(self, state: UpNDownState) -> UpNDownState: bonus = jnp.where(should_award, jnp.int32(self.consts.PASSIVE_SCORE_AMOUNT), jnp.int32(0)) return state._replace(score=state.score + bonus) + + @partial(jax.jit, static_argnums=(0,)) + def _reset_jit(self, key: chex.PRNGKey) -> Tuple[UpNDownObservation, UpNDownState]: + rng_key, flag_key, enemy_key = jax.random.split(key, 3) - - def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: - # Initialize flags at random positions along the track - # Use key for randomness if provided, otherwise use default positions - if key is None: - key = jax.random.PRNGKey(42) - - key, flag_key, enemy_key = jax.random.split(key, 3) # Evenly spread flags along the track with small jitter base_y = jnp.linspace(-900.0, -100.0, self.consts.NUM_FLAGS) jitter = jax.random.uniform(flag_key, shape=(self.consts.NUM_FLAGS,), minval=-40.0, maxval=40.0) @@ -1531,13 +1524,13 @@ def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: # Alternate roads 0/1 for variety flag_roads = jnp.arange(self.consts.NUM_FLAGS) % 2 - + # Calculate which road segment each flag is on based on Y position flag_segments = jax.vmap(self._get_road_segment)(flag_y_offsets) - + # Each flag color index corresponds to its position (0-7) flag_color_indices = jnp.arange(self.consts.NUM_FLAGS) - + flags = Flag( y=flag_y_offsets, road=flag_roads, @@ -1545,14 +1538,14 @@ def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: color_idx=flag_color_indices, collected=jnp.zeros(self.consts.NUM_FLAGS, dtype=jnp.bool_), ) - + # Initialize collectibles as all inactive (will spawn dynamically with mixed types) collectibles = self._initialize_collectibles() # Seed initial visible enemies spaced around the player player_start_y = jnp.array(0.0) enemy_cars = self._initialize_enemies(enemy_key, player_start_y) - + state = UpNDownState( score=0, lives=jnp.array(self.consts.INITIAL_LIVES, dtype=jnp.int32), @@ -1578,6 +1571,7 @@ def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: type=0, ), step_counter=jnp.array(0), + rng_key=rng_key, round_started=jnp.array(False), movement_steps=jnp.array(0), steep_road_timer=jnp.array(0, dtype=jnp.int32), @@ -1591,6 +1585,12 @@ def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: ) initial_obs = self._get_observation(state) return initial_obs, state + + @partial(jax.jit, static_argnums=(0,)) + def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: + if key is None: + key = jax.random.PRNGKey(42) + return self._reset_jit(key) @partial(jax.jit, static_argnums=(0,)) def step(self, state: UpNDownState, action: chex.Array) -> Tuple[UpNDownObservation, UpNDownState, float, bool, UpNDownInfo]: @@ -1617,6 +1617,7 @@ def render(self, state: UpNDownState) -> jnp.ndarray: frame = self.renderer.render(state) return jnp.asarray(frame, dtype=jnp.uint8) + @partial(jax.jit, static_argnums=(0,)) def _get_observation(self, state: UpNDownState): # Clamp to screen-friendly coordinates so observation_space.contains passes x = jnp.int32(jnp.clip(state.player_car.position.x, 0, 160)) From 835b3627713477767956c9c0975a4200606a893d Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sun, 21 Dec 2025 18:34:10 +0100 Subject: [PATCH 67/76] add confirmation logic to start a new round --- src/jaxatari/games/jax_upndown.py | 503 +++++++++++++++++++++++------- 1 file changed, 396 insertions(+), 107 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index a99f23f4e..ca04cba7c 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -28,14 +28,14 @@ class UpNDownConstants(NamedTuple): ENEMY_SPAWN_INTERVAL_MAX: int = 60 # Max spawn interval when many enemies exist ENEMY_MIN_VISIBLE_COUNT: int = 2 # Minimum enemies to keep on screen ENEMY_VISIBLE_DISTANCE: int = 120 # Distance within which enemies are considered "visible" - ENEMY_DESPAWN_DISTANCE: int = 250 + ENEMY_DESPAWN_DISTANCE: int = 250 ENEMY_SPEED_MIN: int = 3 ENEMY_SPEED_MAX: int = 5 ENEMY_DIRECTION_SWITCH_PROB: float = 0.0001 - ENEMY_SPAWN_OFFSET_MIN: float = 70.0 # Closer spawn distance + ENEMY_SPAWN_OFFSET_MIN: float = 70.0 # Closer spawn distance ENEMY_SPAWN_OFFSET_MAX: float = 130.0 # Max spawn offset ENEMY_MIN_SPAWN_GAP: float = 25.0 # Reduced gap between spawns - ENEMY_MAX_AGE: int = 1900 + ENEMY_MAX_AGE: int = 1900 INITIAL_ENEMY_COUNT: int = 4 INITIAL_ENEMY_BASE_OFFSET: float = 35.0 # Closer initial enemies INITIAL_ENEMY_GAP: float = 25.0 # Tighter initial spacing @@ -59,9 +59,9 @@ class UpNDownConstants(NamedTuple): PASSIVE_SCORE_AMOUNT: int = 10 # Points awarded for passive scoring COLLISION_THRESHOLD: float = 5.0 # Distance threshold for flag/collectible collision TRACK_LENGTH: int = 1036 - FIRST_TRACK_CORNERS_X: chex.Array = jnp.array([30, 75, 128, 75, 21, 75, 131, 111, 150, 95, 150, 115, 150, 108, 150, 115, 115, 75, 18, 38, 67, 38, 38, 20, 64, 30]) + FIRST_TRACK_CORNERS_X: chex.Array = jnp.array([30, 75, 128, 75, 21, 75, 131, 111, 150, 95, 150, 115, 150, 108, 150, 115, 115, 75, 18, 38, 67, 38, 38, 20, 64, 30]) TRACK_CORNERS_Y: chex.Array = jnp.array([0, -40, -98, -155, -203, -268, -327, -347, -382, -467, -525, -565, -597, -625, -670, -705, -738, -788, -838, -862, -898, -925, -950, -972, -1000, -1035]) - SECOND_TRACK_CORNERS_X: chex.Array = jnp.array([115, 75, 20, 75, 133, 75, 22, 37, 63, 27, 66, 30, 63, 24, 60, 38, 38, 75, 131, 111, 150, 118, 118, 98, 150, 115]) + SECOND_TRACK_CORNERS_X: chex.Array = jnp.array([115, 75, 20, 75, 133, 75, 22, 37, 63, 27, 66, 30, 63, 24, 60, 38, 38, 75, 131, 111, 150, 118, 118, 98, 150, 115]) PLAYER_SIZE: Tuple[int, int] = (4, 16) INITIAL_ROAD_POS_Y: int = 25 # Flag constants - 8 flags with different colors matching the top row @@ -69,7 +69,7 @@ class UpNDownConstants(NamedTuple): # Flag colors as RGBA values (matching the top row from left to right) FLAG_COLORS: chex.Array = jnp.array([ [184, 50, 50, 255], # Red - [181, 83, 40, 255], # Orange + [181, 83, 40, 255], # Orange [162, 98, 33, 255], # Dark orange [134, 134, 29, 255], # Yellow/olive [200, 72, 72, 255], # Pink (original) @@ -179,15 +179,46 @@ class UpNDownState(NamedTuple): # Enemy cars - dynamic spawning and movement enemy_cars: EnemyCars enemy_spawn_timer: chex.Array + # Death/respawn state - player is dead and waiting for input to respawn + awaiting_respawn: chex.Array # True when player died and is waiting for input + # Round start state - everything frozen and hidden until player presses input + awaiting_round_start: chex.Array # True at game start and after respawn until input received + # Input debounce - requires button release before next input triggers round start + input_released: chex.Array # True when player has released all buttons since last state change class UpNDownObservation(NamedTuple): - player: EntityPosition + """Complete observation for RL agents in Up N Down. + + Reuses existing game classes for consistency: + - player_car: Car with EntityPosition, speed, type, road info + - enemy_cars: EnemyCars pool with positions, speeds, types, active flags + - flags: Flag with y, road, segment, color, collected status + - collectibles: Collectible with positions, types, active status + - Additional game state: score, lives, jumping status, etc. + """ + player_car: Car # Reuse existing Car class + enemy_cars: EnemyCars # Reuse existing EnemyCars class + flags: Flag # Reuse existing Flag class + collectibles: Collectible # Reuse existing Collectible class + flags_collected_mask: jnp.ndarray # Shape (NUM_FLAGS,) - boolean mask + player_score: jnp.ndarray + lives: jnp.ndarray + is_jumping: jnp.ndarray # Whether player is currently jumping + jump_cooldown: jnp.ndarray # Frames remaining in jump + is_on_steep_road: jnp.ndarray # Whether currently on steep section + round_started: jnp.ndarray # Whether player has started moving + class UpNDownInfo(NamedTuple): - time: jnp.ndarray + """Additional info for debugging and analysis.""" + step_counter: jnp.ndarray # Total steps taken + difficulty: jnp.ndarray # Current difficulty level + movement_steps: jnp.ndarray # Steps since round started + jump_slope: jnp.ndarray # Current jump trajectory slope + player_road_segment: jnp.ndarray # Current road segment index class JaxUpNDown(JaxEnvironment[UpNDownState, UpNDownObservation, UpNDownInfo, UpNDownConstants]): def __init__(self, consts: UpNDownConstants = None, reward_funcs: list[callable]=None): consts = consts or UpNDownConstants() @@ -204,7 +235,22 @@ def __init__(self, consts: UpNDownConstants = None, reward_funcs: list[callable] Action.DOWN, Action.DOWNFIRE, ] - self.obs_size = 3*4+1+1 + # Calculate obs_size based on observation structure: + # Player car: 8 values (x, y, w, h, speed, type, road, direction_x) + # Enemy cars: MAX_ENEMY_CARS * 8 = 8 * 8 = 64 (x, y, w, h, speed, type, road, active per car) + # Flags: NUM_FLAGS * 5 = 8 * 5 = 40 (y, road, segment, color, collected per flag) + # Collectibles: MAX_COLLECTIBLES * 5 = 1 * 5 = 5 (y, x, road, type, active per collectible) + # Flags collected mask: NUM_FLAGS = 8 + # Score, lives, is_jumping, jump_cooldown, is_on_steep_road, round_started: 6 + # Total: 8 + 64 + 40 + 5 + 8 + 6 = 131 + self.obs_size = ( + 8 + # player car + self.consts.MAX_ENEMY_CARS * 8 + # enemy cars + self.consts.NUM_FLAGS * 5 + # flags + self.consts.MAX_COLLECTIBLES * 5 + # collectibles + self.consts.NUM_FLAGS + # flags_collected_mask + 6 # score, lives, is_jumping, jump_cooldown, is_on_steep_road, round_started + ) # Speed dividers for movement timing (indexed by speed level) self._speed_dividers = jnp.array([0, 1, 2, 4, 8, 16, 16]) @@ -850,41 +896,41 @@ def check_collision(idx): @partial(jax.jit, static_argnums=(0,)) def _death_step(self, state: UpNDownState) -> UpNDownState: - """Handle player death when on water road (index 2).""" - # Player on water road (index 2 assumed water) + """Handle player death - this is now only used for water crashes during landing. + + When the player dies: + - Lives are decremented + - is_dead is set to True + - awaiting_respawn is set to True + - Player car is moved off-screen (despawned) + - Game waits for player input before respawning + """ + # Skip if already awaiting respawn + already_awaiting = state.awaiting_respawn + + # Player on water road (index 2 assumed water) and not already dead died = jnp.logical_and( - state.player_car.current_road == 2, - ~state.is_dead, + jnp.logical_and( + state.player_car.current_road == 2, + ~state.is_dead, + ), + ~already_awaiting, ) # Use jnp.where for branchless execution lives = jnp.where(died, state.lives - 1, state.lives) - respawn_timer = jnp.where( - died, - jnp.array(self.consts.RESPAWN_DELAY_FRAMES), - jnp.maximum(state.respawn_timer - 1, 0), - ) - is_dead = jnp.logical_and( - jnp.logical_or(state.is_dead, died), - respawn_timer > 0, - ) - - # Respawn player when dead and timer expires - should_respawn = jnp.logical_and(state.is_dead, respawn_timer == 0) - new_position = state.player_car.position._replace( - x=jnp.where(should_respawn, jnp.array(self.consts.RESPAWN_X, dtype=jnp.float32), state.player_car.position.x), - y=jnp.where(should_respawn, jnp.array(self.consts.RESPAWN_Y, dtype=jnp.float32), state.player_car.position.y), - ) + is_dead = jnp.logical_or(state.is_dead, died) + awaiting_respawn = jnp.logical_or(state.awaiting_respawn, died) + + # Stop player movement but keep position (renderer will hide player when awaiting_respawn) player_car = state.player_car._replace( - position=new_position, - speed=jnp.where(should_respawn, 0, state.player_car.speed), - current_road=jnp.where(should_respawn, 0, state.player_car.current_road), + speed=jnp.where(died, 0, state.player_car.speed), ) return state._replace( lives=lives, is_dead=is_dead, - respawn_timer=respawn_timer, + awaiting_respawn=awaiting_respawn, player_car=player_car, ) @@ -1084,9 +1130,22 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: water_crash = jnp.logical_and(is_landing, updated_player_car.current_road == 2) + # On water crash, trigger death state instead of immediate respawn + def trigger_death(s): + # Stop player but keep position (renderer will hide player when awaiting_respawn) + dead_car = s.player_car._replace( + speed=jnp.array(0, dtype=jnp.int32), + ) + return s._replace( + lives=s.lives - 1, + is_dead=jnp.array(True), + awaiting_respawn=jnp.array(True), + player_car=dead_car, + ) + return jax.lax.cond( water_crash, - lambda _: self._respawn_after_collision(next_state, next_state.lives - 1), + lambda _: trigger_death(next_state), lambda _: next_state, operand=None, ) @@ -1204,7 +1263,11 @@ def init_direction(seg, road): @partial(jax.jit, static_argnums=(0,)) def _enemy_step_main(self, state: UpNDownState) -> UpNDownState: """Spawn and move enemy cars with adaptive spawning for consistent enemy presence.""" - rng_key, key_spawn_offset, key_spawn_side, key_spawn_speed, key_spawn_direction, key_spawn_type, key_spawn_sign, key_flip_root = jax.random.split(state.rng_key, 8) + # Split RNG keys - use more splits to ensure better randomization + rng_key, key_spawn_offset, key_spawn_side, key_spawn_speed, key_spawn_direction, key_spawn_type, key_spawn_sign, key_flip_root, key_extra = jax.random.split(state.rng_key, 9) + + # Further split key_spawn_type to get more entropy for type selection + key_spawn_type = jax.random.fold_in(key_spawn_type, state.step_counter) active_mask = state.enemy_cars.active active_count = jnp.sum(active_mask.astype(jnp.int32)) @@ -1417,6 +1480,9 @@ def _respawn_after_collision(self, state: UpNDownState, new_lives: chex.Array) - collectible_spawn_timer=jnp.array(0, dtype=jnp.int32), enemy_cars=enemy_cars, enemy_spawn_timer=jnp.array(self.consts.ENEMY_SPAWN_INTERVAL_BASE, dtype=jnp.int32), + awaiting_respawn=jnp.array(False), + awaiting_round_start=jnp.array(True), # Wait for input to start round after respawn + input_released=jnp.array(False), # Require button release before round can start rng_key=rng_key, ) @@ -1485,9 +1551,18 @@ def handle_late_jump(): return state._replace(score=state.score + bonus, enemy_cars=new_enemy_cars) def handle_ground_collision(): - return self._respawn_after_collision(state, state.lives - 1) + # Trigger death state - stop player but keep position (renderer hides player when awaiting_respawn) + dead_car = state.player_car._replace( + speed=jnp.array(0, dtype=jnp.int32), + ) + return state._replace( + lives=state.lives - 1, + is_dead=jnp.array(True), + awaiting_respawn=jnp.array(True), + player_car=dead_car, + ) - # Ground collision causes respawn (landing is now protected by invincibility) + # Ground collision causes death (landing is now protected by invincibility) any_fatal_collision = grounded_collision return jax.lax.cond( @@ -1582,6 +1657,9 @@ def _reset_jit(self, key: chex.PRNGKey) -> Tuple[UpNDownObservation, UpNDownStat collectible_spawn_timer=jnp.array(0, dtype=jnp.int32), enemy_cars=enemy_cars, enemy_spawn_timer=jnp.array(self.consts.ENEMY_SPAWN_INTERVAL_BASE, dtype=jnp.int32), + awaiting_respawn=jnp.array(False), + awaiting_round_start=jnp.array(True), # Start frozen until first input + input_released=jnp.array(True), # Can start immediately at game start ) initial_obs = self._get_observation(state) return initial_obs, state @@ -1595,15 +1673,62 @@ def reset(self, key=None) -> Tuple[UpNDownObservation, UpNDownState]: @partial(jax.jit, static_argnums=(0,)) def step(self, state: UpNDownState, action: chex.Array) -> Tuple[UpNDownObservation, UpNDownState, float, bool, UpNDownInfo]: previous_state = state - state = self._player_step(state, action) - state = self._death_step(state) - - state = self._passive_score_step_main(state) - state = self._flag_step_main(state) - state = self._completion_bonus_step(state) - state = self._collectible_step_main(state) - state = self._enemy_step_main(state) - state = self._enemy_collision_step_main(state) + + any_action = action != Action.NOOP + + # Track input release - set to True when no button is pressed + input_released = jnp.where(any_action, state.input_released, jnp.array(True)) + state = state._replace(input_released=input_released) + + # Check if we're awaiting respawn - if so, check for input to trigger respawn + should_respawn = jnp.logical_and(state.awaiting_respawn, any_action) + + # Respawn if player pressed any key while awaiting + state = jax.lax.cond( + should_respawn, + lambda s: self._respawn_after_collision(s, s.lives), # lives already decremented + lambda s: s, + state, + ) + + # Check if we're awaiting round start - if so, check for input to start round + # Only start if input was released since respawn (prevents holding button through) + should_start_round = jnp.logical_and( + jnp.logical_and(state.awaiting_round_start, any_action), + state.input_released # Must have released button first + ) + state = jax.lax.cond( + should_start_round, + lambda s: s._replace(awaiting_round_start=jnp.array(False)), + lambda s: s, + state, + ) + + # Skip all game logic if awaiting respawn OR awaiting round start + is_frozen = jnp.logical_or(state.awaiting_respawn, state.awaiting_round_start) + + def run_game_logic(s): + s = self._player_step(s, action) + s = self._death_step(s) + s = self._passive_score_step_main(s) + s = self._flag_step_main(s) + s = self._completion_bonus_step(s) + s = self._collectible_step_main(s) + s = self._enemy_step_main(s) + s = self._enemy_collision_step_main(s) + return s + + def freeze_game(s): + # Only increment step counter while frozen, everything else paused + return s._replace(step_counter=s.step_counter + 1) + + # Run game logic only if not frozen + state = jax.lax.cond( + is_frozen, + freeze_game, + run_game_logic, + state, + ) done = self._get_done(state) env_reward = self._get_reward(previous_state, state) @@ -1618,39 +1743,175 @@ def render(self, state: UpNDownState) -> jnp.ndarray: return jnp.asarray(frame, dtype=jnp.uint8) @partial(jax.jit, static_argnums=(0,)) - def _get_observation(self, state: UpNDownState): - # Clamp to screen-friendly coordinates so observation_space.contains passes - x = jnp.int32(jnp.clip(state.player_car.position.x, 0, 160)) - screen_y = jnp.int32(105) - - player = EntityPosition( - x=x, - y=screen_y, - width=jnp.int32(self.consts.PLAYER_SIZE[0]), - height=jnp.int32(self.consts.PLAYER_SIZE[1]), + def _get_observation(self, state: UpNDownState) -> UpNDownObservation: + """Build complete observation for RL agents. + + Reuses existing game classes directly from state for consistency. + """ + # Check if on steep road + is_on_steep_road = self._is_steep_road_segment( + state.player_car.current_road, + state.player_car.road_index_A, + state.player_car.road_index_B, ) - return UpNDownObservation(player=player) + + return UpNDownObservation( + player_car=state.player_car, + enemy_cars=state.enemy_cars, + flags=state.flags, + collectibles=state.collectibles, + flags_collected_mask=state.flags_collected_mask, + player_score=jnp.int32(state.score), + lives=jnp.int32(state.lives), + is_jumping=jnp.int32(state.is_jumping), + jump_cooldown=jnp.int32(state.jump_cooldown), + is_on_steep_road=jnp.int32(is_on_steep_road), + round_started=jnp.int32(state.round_started), + ) + + @partial(jax.jit, static_argnums=(0,)) + def flatten_car(self, car: Car) -> jnp.ndarray: + """Flatten a Car to a 1D array.""" + return jnp.concatenate([ + jnp.array([car.position.x], dtype=jnp.float32), + jnp.array([car.position.y], dtype=jnp.float32), + jnp.array([car.position.width], dtype=jnp.float32), + jnp.array([car.position.height], dtype=jnp.float32), + jnp.array([car.speed], dtype=jnp.float32), + jnp.array([car.type], dtype=jnp.float32), + jnp.array([car.current_road], dtype=jnp.float32), + jnp.array([car.direction_x], dtype=jnp.float32), + ]) + + @partial(jax.jit, static_argnums=(0,)) + def flatten_enemy_cars(self, enemy_cars: EnemyCars) -> jnp.ndarray: + """Flatten EnemyCars to a 1D array.""" + return jnp.concatenate([ + enemy_cars.position.x, + enemy_cars.position.y, + enemy_cars.position.width, + enemy_cars.position.height, + enemy_cars.speed.astype(jnp.float32), + enemy_cars.type.astype(jnp.float32), + enemy_cars.current_road.astype(jnp.float32), + enemy_cars.active.astype(jnp.float32), + ]) + + @partial(jax.jit, static_argnums=(0,)) + def flatten_flags(self, flags: Flag) -> jnp.ndarray: + """Flatten Flag to a 1D array.""" + return jnp.concatenate([ + flags.y, + flags.road.astype(jnp.float32), + flags.road_segment.astype(jnp.float32), + flags.color_idx.astype(jnp.float32), + flags.collected.astype(jnp.float32), + ]) + + @partial(jax.jit, static_argnums=(0,)) + def flatten_collectibles(self, collectibles: Collectible) -> jnp.ndarray: + """Flatten Collectible to a 1D array.""" + return jnp.concatenate([ + collectibles.y, + collectibles.x, + collectibles.road.astype(jnp.float32), + collectibles.type_id.astype(jnp.float32), + collectibles.active.astype(jnp.float32), + ]) @partial(jax.jit, static_argnums=(0,)) def obs_to_flat_array(self, obs: UpNDownObservation) -> jnp.ndarray: + """Flatten the complete observation to a 1D array for RL. + + Order: + - Player car: 8 values (x, y, w, h, speed, type, road, direction_x) + - Enemy cars: MAX_ENEMY_CARS * 8 values (x, y, w, h, speed, type, road, active per car) + - Flags: NUM_FLAGS * 5 values (y, road, segment, color, collected per flag) + - Collectibles: MAX_COLLECTIBLES * 5 values (y, x, road, type, active per collectible) + - Flags collected mask: NUM_FLAGS values + - Score, lives, is_jumping, jump_cooldown, is_on_steep_road, round_started: 6 values + """ return jnp.concatenate([ - jnp.asarray(obs.player.x, dtype=jnp.int32).reshape(-1), - jnp.asarray(obs.player.y, dtype=jnp.int32).reshape(-1), - jnp.asarray(obs.player.height, dtype=jnp.int32).reshape(-1), - jnp.asarray(obs.player.width, dtype=jnp.int32).reshape(-1), + self.flatten_car(obs.player_car), + self.flatten_enemy_cars(obs.enemy_cars), + self.flatten_flags(obs.flags), + self.flatten_collectibles(obs.collectibles), + obs.flags_collected_mask.flatten().astype(jnp.float32), + jnp.array([obs.player_score], dtype=jnp.float32), + jnp.array([obs.lives], dtype=jnp.float32), + jnp.array([obs.is_jumping], dtype=jnp.float32), + jnp.array([obs.jump_cooldown], dtype=jnp.float32), + jnp.array([obs.is_on_steep_road], dtype=jnp.float32), + jnp.array([obs.round_started], dtype=jnp.float32), ]) def action_space(self) -> spaces.Discrete: return spaces.Discrete(6) - def observation_space(self) -> spaces: + def observation_space(self) -> spaces.Dict: + """Returns the observation space for Up N Down. + + The observation reuses existing game classes: + - player_car: Car with position (x, y, w, h), speed, type, current_road, direction_x + - enemy_cars: EnemyCars with positions, speeds, types, roads, active flags + - flags: Flag with y, road, road_segment, color_idx, collected + - collectibles: Collectible with y, x, road, type_id, active + - flags_collected_mask: boolean array of shape (NUM_FLAGS,) + - player_score: int (0-999999) + - lives: int (0-5) + - is_jumping: int (0 or 1) + - jump_cooldown: int (0-28) + - is_on_steep_road: int (0 or 1) + - round_started: int (0 or 1) + """ return spaces.Dict({ - "player": spaces.Dict({ - "x": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), - "y": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), - "width": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), - "height": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), + "player_car": spaces.Dict({ + "position": spaces.Dict({ + "x": spaces.Box(low=0, high=160, shape=(), dtype=jnp.float32), + "y": spaces.Box(low=-2000, high=0, shape=(), dtype=jnp.float32), + "width": spaces.Box(low=0, high=160, shape=(), dtype=jnp.float32), + "height": spaces.Box(low=0, high=210, shape=(), dtype=jnp.float32), + }), + "speed": spaces.Box(low=-6, high=6, shape=(), dtype=jnp.int32), + "type": spaces.Box(low=0, high=3, shape=(), dtype=jnp.int32), + "current_road": spaces.Box(low=0, high=2, shape=(), dtype=jnp.int32), + "road_index_A": spaces.Box(low=0, high=30, shape=(), dtype=jnp.int32), + "road_index_B": spaces.Box(low=0, high=30, shape=(), dtype=jnp.int32), + "direction_x": spaces.Box(low=-1, high=1, shape=(), dtype=jnp.int32), + }), + "enemy_cars": spaces.Dict({ + "position": spaces.Dict({ + "x": spaces.Box(low=0, high=160, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.float32), + "y": spaces.Box(low=-2000, high=0, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.float32), + "width": spaces.Box(low=0, high=160, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.float32), + "height": spaces.Box(low=0, high=210, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.float32), + }), + "speed": spaces.Box(low=-6, high=6, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), + "type": spaces.Box(low=0, high=3, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), + "current_road": spaces.Box(low=0, high=2, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), + "active": spaces.Box(low=0, high=1, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.bool_), + }), + "flags": spaces.Dict({ + "y": spaces.Box(low=-2000, high=0, shape=(self.consts.NUM_FLAGS,), dtype=jnp.float32), + "road": spaces.Box(low=0, high=1, shape=(self.consts.NUM_FLAGS,), dtype=jnp.int32), + "road_segment": spaces.Box(low=0, high=30, shape=(self.consts.NUM_FLAGS,), dtype=jnp.int32), + "color_idx": spaces.Box(low=0, high=7, shape=(self.consts.NUM_FLAGS,), dtype=jnp.int32), + "collected": spaces.Box(low=0, high=1, shape=(self.consts.NUM_FLAGS,), dtype=jnp.bool_), }), + "collectibles": spaces.Dict({ + "y": spaces.Box(low=-2000, high=0, shape=(self.consts.MAX_COLLECTIBLES,), dtype=jnp.float32), + "x": spaces.Box(low=0, high=160, shape=(self.consts.MAX_COLLECTIBLES,), dtype=jnp.float32), + "road": spaces.Box(low=0, high=1, shape=(self.consts.MAX_COLLECTIBLES,), dtype=jnp.int32), + "type_id": spaces.Box(low=0, high=3, shape=(self.consts.MAX_COLLECTIBLES,), dtype=jnp.int32), + "active": spaces.Box(low=0, high=1, shape=(self.consts.MAX_COLLECTIBLES,), dtype=jnp.bool_), + }), + "flags_collected_mask": spaces.Box(low=0, high=1, shape=(self.consts.NUM_FLAGS,), dtype=jnp.bool_), + "player_score": spaces.Box(low=0, high=999999, shape=(), dtype=jnp.int32), + "lives": spaces.Box(low=0, high=5, shape=(), dtype=jnp.int32), + "is_jumping": spaces.Box(low=0, high=1, shape=(), dtype=jnp.int32), + "jump_cooldown": spaces.Box(low=0, high=28, shape=(), dtype=jnp.int32), + "is_on_steep_road": spaces.Box(low=0, high=1, shape=(), dtype=jnp.int32), + "round_started": spaces.Box(low=0, high=1, shape=(), dtype=jnp.int32), }) def image_space(self) -> spaces.Box: @@ -1662,8 +1923,22 @@ def image_space(self) -> spaces.Box: ) @partial(jax.jit, static_argnums=(0,)) - def _get_info(self, state: UpNDownState, ) -> UpNDownInfo: - return UpNDownInfo(time=jnp.asarray(state.step_counter, dtype=jnp.int32)) + def _get_info(self, state: UpNDownState) -> UpNDownInfo: + """Build info dict with additional debugging/analysis data.""" + # Get current road segment for player + road_index = jnp.where( + state.player_car.current_road == 0, + state.player_car.road_index_A, + state.player_car.road_index_B, + ) + + return UpNDownInfo( + step_counter=jnp.int32(state.step_counter), + difficulty=jnp.int32(state.difficulty), + movement_steps=jnp.int32(state.movement_steps), + jump_slope=jnp.float32(state.jump_slope), + player_road_segment=jnp.int32(road_index), + ) @partial(jax.jit, static_argnums=(0,)) def _get_reward(self, previous_state: UpNDownState, state: UpNDownState): @@ -1707,11 +1982,10 @@ def _createBackgroundSprite(self, dimensions: Tuple[int, int]) -> jnp.ndarray: temp_pointer = self._createBackgroundSprite((1, 1)) blackout_square = self._createBackgroundSprite(self.consts.FLAG_BLACKOUT_SIZE) - # 2. Update asset config to include both walls + # Build asset config locally (matches other games' pattern) asset_config, road_files = self._get_asset_config(background, top_block, bottom_block, temp_pointer, blackout_square) sprite_path = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/" - # 3. Make a single call to the setup function ( self.PALETTE, self.SHAPE_MASKS, @@ -1800,6 +2074,37 @@ def _get_road_sprite_sizes(self, road_files: list[str]) -> list: sizes.append(sprite.shape[0]) complete_size = int(sum(sizes)) return sizes, complete_size + + def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.ndarray, bottomBlockSprite: jnp.ndarray, tempPointer: jnp.ndarray, blackoutSquare: jnp.ndarray) -> tuple[list, list[str]]: + """Return asset manifest and ordered road files (renderer-local like other games).""" + road_dir = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/roads" + road_files = sorted( + file for file in os.listdir(road_dir) + if file.endswith(".npy") + ) + roads = [f"roads/{file}" for file in road_files] + return [ + {'name': 'background', 'type': 'background', 'data': backgroundSprite}, + {'name': 'road', 'type': 'group', 'files': roads}, + {'name': 'player', 'type': 'single', 'file': 'player_car.npy'}, + {'name': 'camero_left', 'type': 'single', 'file': 'enemy_cars/camero_left.npy'}, + {'name': 'flag_carrier_left', 'type': 'single', 'file': 'enemy_cars/flag_carrier_left.npy'}, + {'name': 'pick_up_truck_left', 'type': 'single', 'file': 'enemy_cars/pick_up_truck_left.npy'}, + {'name': 'truck_left', 'type': 'single', 'file': 'enemy_cars/truck_left.npy'}, + {'name': 'wall_top', 'type': 'procedural', 'data': topBlockSprite}, + {'name': 'wall_bottom', 'type': 'procedural', 'data': bottomBlockSprite}, + {'name': 'all_flags_top', 'type': 'single', 'file': 'all_flags_top.npy'}, + {'name': 'all_lives_bottom', 'type': 'single', 'file': 'all_lives_bottom.npy'}, + {'name': 'score_digits', 'type': 'digits', 'pattern': 'score/score_{}.npy'}, + {'name': 'pink_flag', 'type': 'single', 'file': 'pink_flag.npy'}, + {'name': 'flag_pole', 'type': 'single', 'file': 'flag_pole.npy'}, + {'name': 'cherry', 'type': 'single', 'file': 'cherry.npy'}, + {'name': 'balloon', 'type': 'single', 'file': 'balloon.npy'}, + {'name': 'lollypop', 'type': 'single', 'file': 'lollypop.npy'}, + {'name': 'ice_cream', 'type': 'single', 'file': 'ice_cream_cone.npy'}, + {'name': 'tempPointer', 'type': 'procedural', 'data': tempPointer}, + {'name': 'blackout_square', 'type': 'procedural', 'data': blackoutSquare}, + ], roads def _get_x_on_road(self, y: chex.Array, road_segment: chex.Array, track_corners_x: chex.Array) -> chex.Array: """Calculate the X position on a road given a Y coordinate and road segment.""" @@ -1832,38 +2137,6 @@ def _jump_arc_offset(self, jump_cooldown: chex.Array) -> chex.Array: centered = (progress - 0.5) * 2.0 return self.consts.JUMP_ARC_HEIGHT * (1.0 - centered * centered) - def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.ndarray, bottomBlockSprite: jnp.ndarray, tempPointer: jnp.ndarray, blackoutSquare: jnp.ndarray) -> tuple[list, list[str]]: - """Returns the asset manifest and ordered road files.""" - road_dir = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/roads" - road_files = sorted( - file for file in os.listdir(road_dir) - if file.endswith(".npy") - ) - roads = [f"roads/{file}" for file in road_files] - return [ - {'name': 'background', 'type': 'background', 'data': backgroundSprite}, - {'name': 'road', 'type': 'group', 'files': roads}, - {'name': 'player', 'type': 'single', 'file': 'player_car.npy'}, - # Only load left-facing enemy sprites; right-facing are created by flipping - {'name': 'camero_left', 'type': 'single', 'file': 'enemy_cars/camero_left.npy'}, - {'name': 'flag_carrier_left', 'type': 'single', 'file': 'enemy_cars/flag_carrier_left.npy'}, - {'name': 'pick_up_truck_left', 'type': 'single', 'file': 'enemy_cars/pick_up_truck_left.npy'}, - {'name': 'truck_left', 'type': 'single', 'file': 'enemy_cars/truck_left.npy'}, - {'name': 'wall_top', 'type': 'procedural', 'data': topBlockSprite}, - {'name': 'wall_bottom', 'type': 'procedural', 'data': bottomBlockSprite}, - {'name': 'all_flags_top', 'type': 'single', 'file': 'all_flags_top.npy'}, - {'name': 'all_lives_bottom', 'type': 'single', 'file': 'all_lives_bottom.npy'}, - {'name': 'score_digits', 'type': 'digits', 'pattern': 'score/score_{}.npy'}, - {'name': 'pink_flag', 'type': 'single', 'file': 'pink_flag.npy'}, - {'name': 'flag_pole', 'type': 'single', 'file': 'flag_pole.npy'}, - {'name': 'cherry', 'type': 'single', 'file': 'cherry.npy'}, - {'name': 'balloon', 'type': 'single', 'file': 'balloon.npy'}, - {'name': 'lollypop', 'type': 'single', 'file': 'lollypop.npy'}, - {'name': 'ice_cream', 'type': 'single', 'file': 'ice_cream_cone.npy'}, - {'name': 'tempPointer', 'type': 'procedural', 'data': tempPointer}, - {'name': 'blackout_square', 'type': 'procedural', 'data': blackoutSquare}, - ], roads - @partial(jax.jit, static_argnums=(0,)) def render(self, state): raster = self.jr.create_object_raster(self.BACKGROUND) @@ -1928,7 +2201,12 @@ def render_enemy(carry, enemy_idx): enemy_type = state.enemy_cars.type[enemy_idx] direction_x = state.enemy_cars.direction_x[enemy_idx] screen_y = 105 + (enemy_y - state.player_car.position.y) - is_visible = jnp.logical_and(enemy_active, jnp.logical_and(screen_y > 25, screen_y < 195)) + # Hide enemies when awaiting round start or awaiting respawn + should_hide = jnp.logical_or(state.awaiting_round_start, state.awaiting_respawn) + is_visible = jnp.logical_and( + jnp.logical_and(enemy_active, jnp.logical_and(screen_y > 25, screen_y < 195)), + ~should_hide + ) enemy_mask = select_enemy_mask(enemy_type, direction_x < 0) raster = jax.lax.cond( @@ -1950,7 +2228,14 @@ def render_enemy(carry, enemy_idx): player_screen_y = jnp.int32(105 - jump_offset) player_mask = self.SHAPE_MASKS["player"] - raster_player = self.jr.render_at_clipped(raster_enemies, state.player_car.position.x, player_screen_y, player_mask) + # Skip rendering player when awaiting respawn OR awaiting round start + should_hide_player = jnp.logical_or(state.awaiting_respawn, state.awaiting_round_start) + raster_player = jax.lax.cond( + should_hide_player, + lambda _: raster_enemies, # Don't render player + lambda _: self.jr.render_at_clipped(raster_enemies, state.player_car.position.x, player_screen_y, player_mask), + operand=None, + ) wall_top_mask = self.SHAPE_MASKS["wall_top"] raster_wall_top = self.jr.render_at(raster_player, 0, 0, wall_top_mask) @@ -2002,9 +2287,11 @@ def render_flag(carry, flag_idx): operand=None, ) screen_y = 105 + (flag_y - state.player_car.position.y) + # Hide flags when awaiting round start or awaiting respawn + should_hide = jnp.logical_or(state.awaiting_round_start, state.awaiting_respawn) is_visible = jnp.logical_and( jnp.logical_and(screen_y > 25, screen_y < 195), - ~flag_collected + jnp.logical_and(~flag_collected, ~should_hide) ) color_id = self.flag_palette_ids[flag_color_idx] colored_flag_mask = jnp.where( @@ -2050,9 +2337,11 @@ def render_collectible(carry, collectible_idx): collectible_color_idx = state.collectibles.color_idx[collectible_idx] collectible_type_id = state.collectibles.type_id[collectible_idx] screen_y = 105 + (collectible_y - state.player_car.position.y) + # Hide collectibles when awaiting round start or awaiting respawn + should_hide = jnp.logical_or(state.awaiting_round_start, state.awaiting_respawn) is_visible = jnp.logical_and( jnp.logical_and(screen_y > 25, screen_y < 195), - collectible_active + jnp.logical_and(collectible_active, ~should_hide) ) def get_sprite_and_mask(type_id): From 7add44ec4fcfea23c3b043c455180e8641818e4c Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sun, 21 Dec 2025 19:37:42 +0100 Subject: [PATCH 68/76] fix observation logic for tests --- src/jaxatari/games/jax_upndown.py | 168 +++++++++++++++--------------- 1 file changed, 85 insertions(+), 83 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index ca04cba7c..7bf6abcab 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -188,28 +188,18 @@ class UpNDownState(NamedTuple): - class UpNDownObservation(NamedTuple): - """Complete observation for RL agents in Up N Down. - - Reuses existing game classes for consistency: - - player_car: Car with EntityPosition, speed, type, road info - - enemy_cars: EnemyCars pool with positions, speeds, types, active flags - - flags: Flag with y, road, segment, color, collected status - - collectibles: Collectible with positions, types, active status - - Additional game state: score, lives, jumping status, etc. - """ - player_car: Car # Reuse existing Car class - enemy_cars: EnemyCars # Reuse existing EnemyCars class - flags: Flag # Reuse existing Flag class - collectibles: Collectible # Reuse existing Collectible class - flags_collected_mask: jnp.ndarray # Shape (NUM_FLAGS,) - boolean mask - player_score: jnp.ndarray - lives: jnp.ndarray - is_jumping: jnp.ndarray # Whether player is currently jumping - jump_cooldown: jnp.ndarray # Frames remaining in jump - is_on_steep_road: jnp.ndarray # Whether currently on steep section - round_started: jnp.ndarray # Whether player has started moving + player_car: Car + enemy_cars: EnemyCars + flags: Flag + collectibles: Collectible + flags_collected_mask: chex.Array # Shape (NUM_FLAGS,) - int32 (0 or 1) + player_score: chex.Array + lives: chex.Array + is_jumping: chex.Array + jump_cooldown: chex.Array + is_on_steep_road: chex.Array + round_started: chex.Array class UpNDownInfo(NamedTuple): @@ -236,18 +226,18 @@ def __init__(self, consts: UpNDownConstants = None, reward_funcs: list[callable] Action.DOWNFIRE, ] # Calculate obs_size based on observation structure: - # Player car: 8 values (x, y, w, h, speed, type, road, direction_x) - # Enemy cars: MAX_ENEMY_CARS * 8 = 8 * 8 = 64 (x, y, w, h, speed, type, road, active per car) + # Player car: 10 values (x, y, w, h, speed, type, road, road_index_A, road_index_B, direction_x) + # Enemy cars: MAX_ENEMY_CARS * 12 = 8 * 12 = 96 (x, y, w, h, speed, type, road, road_index_A, road_index_B, direction_x, active, age) # Flags: NUM_FLAGS * 5 = 8 * 5 = 40 (y, road, segment, color, collected per flag) - # Collectibles: MAX_COLLECTIBLES * 5 = 1 * 5 = 5 (y, x, road, type, active per collectible) + # Collectibles: MAX_COLLECTIBLES * 6 = 1 * 6 = 6 (y, x, road, color_idx, type, active per collectible) # Flags collected mask: NUM_FLAGS = 8 # Score, lives, is_jumping, jump_cooldown, is_on_steep_road, round_started: 6 - # Total: 8 + 64 + 40 + 5 + 8 + 6 = 131 + # Total: 10 + 96 + 40 + 6 + 8 + 6 = 166 self.obs_size = ( - 8 + # player car - self.consts.MAX_ENEMY_CARS * 8 + # enemy cars + 10 + # player car + self.consts.MAX_ENEMY_CARS * 12 + # enemy cars (all fields) self.consts.NUM_FLAGS * 5 + # flags - self.consts.MAX_COLLECTIBLES * 5 + # collectibles + self.consts.MAX_COLLECTIBLES * 6 + # collectibles (all fields) self.consts.NUM_FLAGS + # flags_collected_mask 6 # score, lives, is_jumping, jump_cooldown, is_on_steep_road, round_started ) @@ -1746,7 +1736,7 @@ def render(self, state: UpNDownState) -> jnp.ndarray: def _get_observation(self, state: UpNDownState) -> UpNDownObservation: """Build complete observation for RL agents. - Reuses existing game classes directly from state for consistency. + Reuses existing game classes directly. Extra fields are filtered during flatten. """ # Check if on steep road is_on_steep_road = self._is_steep_road_segment( @@ -1760,7 +1750,7 @@ def _get_observation(self, state: UpNDownState) -> UpNDownObservation: enemy_cars=state.enemy_cars, flags=state.flags, collectibles=state.collectibles, - flags_collected_mask=state.flags_collected_mask, + flags_collected_mask=state.flags_collected_mask.astype(jnp.int32), player_score=jnp.int32(state.score), lives=jnp.int32(state.lives), is_jumping=jnp.int32(state.is_jumping), @@ -1773,50 +1763,57 @@ def _get_observation(self, state: UpNDownState) -> UpNDownObservation: def flatten_car(self, car: Car) -> jnp.ndarray: """Flatten a Car to a 1D array.""" return jnp.concatenate([ - jnp.array([car.position.x], dtype=jnp.float32), - jnp.array([car.position.y], dtype=jnp.float32), - jnp.array([car.position.width], dtype=jnp.float32), - jnp.array([car.position.height], dtype=jnp.float32), - jnp.array([car.speed], dtype=jnp.float32), - jnp.array([car.type], dtype=jnp.float32), - jnp.array([car.current_road], dtype=jnp.float32), - jnp.array([car.direction_x], dtype=jnp.float32), + jnp.array([car.position.x], dtype=jnp.int32), + jnp.array([car.position.y], dtype=jnp.int32), + jnp.array([car.position.width], dtype=jnp.int32), + jnp.array([car.position.height], dtype=jnp.int32), + jnp.array([car.speed], dtype=jnp.int32), + jnp.array([car.type], dtype=jnp.int32), + jnp.array([car.current_road], dtype=jnp.int32), + jnp.array([car.road_index_A], dtype=jnp.int32), + jnp.array([car.road_index_B], dtype=jnp.int32), + jnp.array([car.direction_x], dtype=jnp.int32), ]) @partial(jax.jit, static_argnums=(0,)) def flatten_enemy_cars(self, enemy_cars: EnemyCars) -> jnp.ndarray: - """Flatten EnemyCars to a 1D array.""" + """Flatten EnemyCars to a 1D array (all fields).""" return jnp.concatenate([ - enemy_cars.position.x, - enemy_cars.position.y, - enemy_cars.position.width, - enemy_cars.position.height, - enemy_cars.speed.astype(jnp.float32), - enemy_cars.type.astype(jnp.float32), - enemy_cars.current_road.astype(jnp.float32), - enemy_cars.active.astype(jnp.float32), + enemy_cars.position.x.astype(jnp.int32), + enemy_cars.position.y.astype(jnp.int32), + enemy_cars.position.width.astype(jnp.int32), + enemy_cars.position.height.astype(jnp.int32), + enemy_cars.speed.astype(jnp.int32), + enemy_cars.type.astype(jnp.int32), + enemy_cars.current_road.astype(jnp.int32), + enemy_cars.road_index_A.astype(jnp.int32), + enemy_cars.road_index_B.astype(jnp.int32), + enemy_cars.direction_x.astype(jnp.int32), + enemy_cars.active.astype(jnp.int32), + enemy_cars.age.astype(jnp.int32), ]) @partial(jax.jit, static_argnums=(0,)) def flatten_flags(self, flags: Flag) -> jnp.ndarray: """Flatten Flag to a 1D array.""" return jnp.concatenate([ - flags.y, - flags.road.astype(jnp.float32), - flags.road_segment.astype(jnp.float32), - flags.color_idx.astype(jnp.float32), - flags.collected.astype(jnp.float32), + flags.y.astype(jnp.int32), + flags.road.astype(jnp.int32), + flags.road_segment.astype(jnp.int32), + flags.color_idx.astype(jnp.int32), + flags.collected.astype(jnp.int32), ]) @partial(jax.jit, static_argnums=(0,)) def flatten_collectibles(self, collectibles: Collectible) -> jnp.ndarray: - """Flatten Collectible to a 1D array.""" + """Flatten Collectible to a 1D array (all fields).""" return jnp.concatenate([ - collectibles.y, - collectibles.x, - collectibles.road.astype(jnp.float32), - collectibles.type_id.astype(jnp.float32), - collectibles.active.astype(jnp.float32), + collectibles.y.astype(jnp.int32), + collectibles.x.astype(jnp.int32), + collectibles.road.astype(jnp.int32), + collectibles.color_idx.astype(jnp.int32), + collectibles.type_id.astype(jnp.int32), + collectibles.active.astype(jnp.int32), ]) @partial(jax.jit, static_argnums=(0,)) @@ -1824,10 +1821,10 @@ def obs_to_flat_array(self, obs: UpNDownObservation) -> jnp.ndarray: """Flatten the complete observation to a 1D array for RL. Order: - - Player car: 8 values (x, y, w, h, speed, type, road, direction_x) - - Enemy cars: MAX_ENEMY_CARS * 8 values (x, y, w, h, speed, type, road, active per car) + - Player car: 10 values (x, y, w, h, speed, type, road, road_index_A, road_index_B, direction_x) + - Enemy cars: MAX_ENEMY_CARS * 12 values (x, y, w, h, speed, type, road, road_index_A, road_index_B, direction_x, active, age) - Flags: NUM_FLAGS * 5 values (y, road, segment, color, collected per flag) - - Collectibles: MAX_COLLECTIBLES * 5 values (y, x, road, type, active per collectible) + - Collectibles: MAX_COLLECTIBLES * 6 values (y, x, road, color_idx, type, active per collectible) - Flags collected mask: NUM_FLAGS values - Score, lives, is_jumping, jump_cooldown, is_on_steep_road, round_started: 6 values """ @@ -1836,13 +1833,13 @@ def obs_to_flat_array(self, obs: UpNDownObservation) -> jnp.ndarray: self.flatten_enemy_cars(obs.enemy_cars), self.flatten_flags(obs.flags), self.flatten_collectibles(obs.collectibles), - obs.flags_collected_mask.flatten().astype(jnp.float32), - jnp.array([obs.player_score], dtype=jnp.float32), - jnp.array([obs.lives], dtype=jnp.float32), - jnp.array([obs.is_jumping], dtype=jnp.float32), - jnp.array([obs.jump_cooldown], dtype=jnp.float32), - jnp.array([obs.is_on_steep_road], dtype=jnp.float32), - jnp.array([obs.round_started], dtype=jnp.float32), + obs.flags_collected_mask.flatten().astype(jnp.int32), + jnp.array([obs.player_score], dtype=jnp.int32), + jnp.array([obs.lives], dtype=jnp.int32), + jnp.array([obs.is_jumping], dtype=jnp.int32), + jnp.array([obs.jump_cooldown], dtype=jnp.int32), + jnp.array([obs.is_on_steep_road], dtype=jnp.int32), + jnp.array([obs.round_started], dtype=jnp.int32), ]) def action_space(self) -> spaces.Discrete: @@ -1867,10 +1864,10 @@ def observation_space(self) -> spaces.Dict: return spaces.Dict({ "player_car": spaces.Dict({ "position": spaces.Dict({ - "x": spaces.Box(low=0, high=160, shape=(), dtype=jnp.float32), - "y": spaces.Box(low=-2000, high=0, shape=(), dtype=jnp.float32), - "width": spaces.Box(low=0, high=160, shape=(), dtype=jnp.float32), - "height": spaces.Box(low=0, high=210, shape=(), dtype=jnp.float32), + "x": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), + "y": spaces.Box(low=-2000, high=0, shape=(), dtype=jnp.int32), + "width": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), + "height": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), }), "speed": spaces.Box(low=-6, high=6, shape=(), dtype=jnp.int32), "type": spaces.Box(low=0, high=3, shape=(), dtype=jnp.int32), @@ -1881,31 +1878,36 @@ def observation_space(self) -> spaces.Dict: }), "enemy_cars": spaces.Dict({ "position": spaces.Dict({ - "x": spaces.Box(low=0, high=160, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.float32), - "y": spaces.Box(low=-2000, high=0, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.float32), - "width": spaces.Box(low=0, high=160, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.float32), - "height": spaces.Box(low=0, high=210, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.float32), + "x": spaces.Box(low=0, high=160, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), + "y": spaces.Box(low=-2000, high=0, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), + "width": spaces.Box(low=0, high=160, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), + "height": spaces.Box(low=0, high=210, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), }), "speed": spaces.Box(low=-6, high=6, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), "type": spaces.Box(low=0, high=3, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), "current_road": spaces.Box(low=0, high=2, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), - "active": spaces.Box(low=0, high=1, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.bool_), + "road_index_A": spaces.Box(low=0, high=30, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), + "road_index_B": spaces.Box(low=0, high=30, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), + "direction_x": spaces.Box(low=-1, high=1, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), + "active": spaces.Box(low=0, high=1, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), + "age": spaces.Box(low=0, high=10000, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), }), "flags": spaces.Dict({ - "y": spaces.Box(low=-2000, high=0, shape=(self.consts.NUM_FLAGS,), dtype=jnp.float32), + "y": spaces.Box(low=-2000, high=0, shape=(self.consts.NUM_FLAGS,), dtype=jnp.int32), "road": spaces.Box(low=0, high=1, shape=(self.consts.NUM_FLAGS,), dtype=jnp.int32), "road_segment": spaces.Box(low=0, high=30, shape=(self.consts.NUM_FLAGS,), dtype=jnp.int32), "color_idx": spaces.Box(low=0, high=7, shape=(self.consts.NUM_FLAGS,), dtype=jnp.int32), - "collected": spaces.Box(low=0, high=1, shape=(self.consts.NUM_FLAGS,), dtype=jnp.bool_), + "collected": spaces.Box(low=0, high=1, shape=(self.consts.NUM_FLAGS,), dtype=jnp.int32), }), "collectibles": spaces.Dict({ - "y": spaces.Box(low=-2000, high=0, shape=(self.consts.MAX_COLLECTIBLES,), dtype=jnp.float32), - "x": spaces.Box(low=0, high=160, shape=(self.consts.MAX_COLLECTIBLES,), dtype=jnp.float32), + "y": spaces.Box(low=-2000, high=0, shape=(self.consts.MAX_COLLECTIBLES,), dtype=jnp.int32), + "x": spaces.Box(low=0, high=160, shape=(self.consts.MAX_COLLECTIBLES,), dtype=jnp.int32), "road": spaces.Box(low=0, high=1, shape=(self.consts.MAX_COLLECTIBLES,), dtype=jnp.int32), + "color_idx": spaces.Box(low=0, high=7, shape=(self.consts.MAX_COLLECTIBLES,), dtype=jnp.int32), "type_id": spaces.Box(low=0, high=3, shape=(self.consts.MAX_COLLECTIBLES,), dtype=jnp.int32), - "active": spaces.Box(low=0, high=1, shape=(self.consts.MAX_COLLECTIBLES,), dtype=jnp.bool_), + "active": spaces.Box(low=0, high=1, shape=(self.consts.MAX_COLLECTIBLES,), dtype=jnp.int32), }), - "flags_collected_mask": spaces.Box(low=0, high=1, shape=(self.consts.NUM_FLAGS,), dtype=jnp.bool_), + "flags_collected_mask": spaces.Box(low=0, high=1, shape=(self.consts.NUM_FLAGS,), dtype=jnp.int32), "player_score": spaces.Box(low=0, high=999999, shape=(), dtype=jnp.int32), "lives": spaces.Box(low=0, high=5, shape=(), dtype=jnp.int32), "is_jumping": spaces.Box(low=0, high=1, shape=(), dtype=jnp.int32), From a962b54661913c922052f173afe28ed1b8ea08e3 Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sat, 31 Jan 2026 17:14:14 +0100 Subject: [PATCH 69/76] try to implement some of the feedback --- src/jaxatari/games/jax_upndown.py | 291 ++++++++++++++++++++++-------- 1 file changed, 216 insertions(+), 75 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 7bf6abcab..5b903ab0b 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -15,9 +15,9 @@ class UpNDownConstants(NamedTuple): FRAME_SKIP: int = 4 DIFFICULTIES: chex.Array = jnp.array([0, 1, 2, 3, 4, 5]) - MAX_SPEED: int = 6 + MAX_SPEED: int = 7 INITIAL_LIVES: int = 5 - JUMP_ARC_HEIGHT: float = 18.0 + JUMP_ARC_HEIGHT: float = 22.0 RESPAWN_DELAY_FRAMES: int = 60 RESPAWN_Y: int = 0 RESPAWN_X: int = 30 @@ -58,6 +58,8 @@ class UpNDownConstants(NamedTuple): PASSIVE_SCORE_INTERVAL: int = 60 # Steps between passive score awards PASSIVE_SCORE_AMOUNT: int = 10 # Points awarded for passive scoring COLLISION_THRESHOLD: float = 5.0 # Distance threshold for flag/collectible collision + ACCELERATION_INTERVAL: int = 6 # Frames between speed changes when holding up/down + EXTRA_LIFE_THRESHOLD: int = 10000 # Score threshold for extra life TRACK_LENGTH: int = 1036 FIRST_TRACK_CORNERS_X: chex.Array = jnp.array([30, 75, 128, 75, 21, 75, 131, 111, 150, 95, 150, 115, 150, 108, 150, 115, 115, 75, 18, 38, 67, 38, 38, 20, 64, 30]) TRACK_CORNERS_Y: chex.Array = jnp.array([0, -40, -98, -155, -203, -268, -327, -347, -382, -467, -525, -565, -597, -625, -670, -705, -738, -788, -838, -862, -898, -925, -950, -972, -1000, -1035]) @@ -185,6 +187,9 @@ class UpNDownState(NamedTuple): awaiting_round_start: chex.Array # True at game start and after respawn until input received # Input debounce - requires button release before next input triggers round start input_released: chex.Array # True when player has released all buttons since last state change + jump_key_released: chex.Array # True if jump button was NOT pressed in previous step + last_extra_life_score: chex.Array # Score at which last extra life was awarded + jump_total_duration: chex.Array # Total duration of the current/last jump for rendering arc @@ -242,7 +247,7 @@ def __init__(self, consts: UpNDownConstants = None, reward_funcs: list[callable] 6 # score, lives, is_jumping, jump_cooldown, is_on_steep_road, round_started ) # Speed dividers for movement timing (indexed by speed level) - self._speed_dividers = jnp.array([0, 1, 2, 4, 8, 16, 16]) + self._speed_dividers = jnp.array([0, 1, 2, 4, 8, 16, 16, 16, 16]) @partial(jax.jit, static_argnums=(0,)) def _compute_movement_timing(self, speed: chex.Array, step_counter: chex.Array) -> Tuple[chex.Array, chex.Array, chex.Array, chex.Array]: @@ -261,7 +266,7 @@ def _compute_movement_timing(self, speed: chex.Array, step_counter: chex.Array) move_y = jnp.logical_and((step_counter % period) == (half_period % period), speed != 0) move_x = jnp.logical_and((step_counter % period) == 0, speed != 0) - step_size = jnp.where(speed_index >= self.consts.MAX_SPEED, 1.5, 1.0) + step_size = jnp.where(speed_index >= 6, 1.5 + (speed_index - 6) * 0.2, 1.0) return move_y, move_x, step_size, speed_sign def _apply_steep_road_penalty( @@ -547,6 +552,7 @@ def _advance_player_car( car_type: chex.Array, is_landing: chex.Array, stored_jump_slope: chex.Array, + jump_progress: chex.Array, ) -> Car: """ Advance the player car position. @@ -574,16 +580,59 @@ def _advance_player_car( position, slope, b, speed_sign, step_size, car_direction_x, move_y, move_x ) + # === JUMP PHYSICS NORMALIZATION === + # Normalize jump velocity so total speed (Euclidean) matches 'step_size' + # Without this, diagonal jumps cover more distance per frame than straight road movement + # stored_jump_slope is dX/dY + # Scaling factor = 1 / sqrt(1 + slope^2) + jump_speed_scaling = 1.0 / jnp.sqrt(1.0 + stored_jump_slope**2) + jump_step_size = step_size * jump_speed_scaling + # === Y MOVEMENT === - # When jumping: move freely in Y direction + # When jumping: move freely in Y direction but with normalized speed # When on road: use road-based movement result - jump_y = jnp.where(move_y, position_y + speed_sign * -step_size, position_y) + # Note: We must apply step_y on move_y ticks to keep sync with engine heartbeat + jump_y = jnp.where(move_y, position_y + speed_sign * -jump_step_size, position_y) new_player_y = jnp.where(is_jumping, jump_y, road_y) # === X MOVEMENT === # When jumping: use stored_jump_slope (locked at jump start) - moves X proportionally to Y - # When on road: use road-based movement result - jump_x = jnp.where(move_x, position_x - speed_sign * stored_jump_slope * step_size, position_x) + # Use jump_step_size to maintain correct trajectory and speed + # X step = slope * Y step magnitude = slope * jump_step_size + raw_jump_x = jnp.where(move_x, position_x - speed_sign * stored_jump_slope * jump_step_size, position_x) + + # === AIR STEERING / MAGNETISM === + # Gradually steer towards the nearest road while in the air to prevent "teleporting" on landing + segment_curr = self._get_road_segment(new_player_y) + road_A_x_curr = self._get_x_on_road(new_player_y, segment_curr, self.consts.FIRST_TRACK_CORNERS_X) + road_B_x_curr = self._get_x_on_road(new_player_y, segment_curr, self.consts.SECOND_TRACK_CORNERS_X) + + dist_A = jnp.abs(raw_jump_x - road_A_x_curr) + dist_B = jnp.abs(raw_jump_x - road_B_x_curr) + + # Find closest road center + target_road_x = jnp.where(dist_A < dist_B, road_A_x_curr, road_B_x_curr) + dist_to_target = target_road_x - raw_jump_x + + # Only nudge in the last 25% of the jump (progress > 0.75) + # when reasonably close to a road (within 2x tolerance) + # and only when player is between the two roads + + is_late_jump = jump_progress > 0.75 + is_reasonably_close = jnp.abs(dist_to_target) < (self.consts.LANDING_TOLERANCE * 2.0) + + # Check if player is between the two roads + min_road_x_curr = jnp.minimum(road_A_x_curr, road_B_x_curr) + max_road_x_curr = jnp.maximum(road_A_x_curr, road_B_x_curr) + is_between_roads = jnp.logical_and(raw_jump_x > min_road_x_curr, raw_jump_x < max_road_x_curr) + + should_magnet = jnp.logical_and(is_late_jump, jnp.logical_and(is_reasonably_close, is_between_roads)) + + # Nudge factor: reduced to 2% steering strength (very subtle) + nudge_amount = dist_to_target * 0.08 + + jump_x = raw_jump_x + jnp.where(should_magnet, nudge_amount, 0.0) + new_player_x = jnp.where(is_jumping, jump_x, road_x) # === LANDING LOGIC === @@ -617,12 +666,41 @@ def _advance_player_car( # Valid landing: on a road OR between roads (will snap to nearest) valid_landing = jnp.logical_or(on_any_road, between_roads) + # Bridge crossing physics: if speed is high, we can "skip" small water gaps (land on nearest road) + # In original game, bridges allow crossing without jumping if you have speed + can_bridge_gap = jnp.abs(speed) >= 5 + # If landing and between roads but not directly on a road, snap to nearest road should_snap = jnp.logical_and(is_landing, jnp.logical_and(between_roads, jnp.logical_not(on_any_road))) - final_player_x = jnp.where(should_snap, nearest_road_x, new_player_x) + # Also snap if we are "in water" but have speed to bridge the gap + should_snap_bridge = jnp.logical_and(is_landing, jnp.logical_and(can_bridge_gap, jnp.logical_not(valid_landing))) + + final_player_x = jnp.where(jnp.logical_or(should_snap, should_snap_bridge), nearest_road_x, new_player_x) - # Water landing (crash): landing outside the valid road area - landing_in_water = jnp.logical_and(is_landing, jnp.logical_not(valid_landing)) + # Water landing (crash): Only if NOT on road AND NOT between roads (i.e., landed completely outside) + # User clarification: "crashing should only be possible if you dont land in betweeen or on the roads" + + # Safe if: ON ROAD or BETWEEN ROADS + is_safe_landing = jnp.logical_or(on_any_road, between_roads) + + landing_in_water = jnp.logical_and( + is_landing, + jnp.logical_not(is_safe_landing) + ) + + # Snap logic: + # If landing BETWEEN roads but not ON a road -> snap to nearest (safe!) + # (Outside landings are now crashes, so no need to snap them) + should_snap = jnp.logical_and(is_landing, jnp.logical_and(between_roads, jnp.logical_not(on_any_road))) + + # Also snap if bridging (fast jump across water gap) + should_snap_bridge = jnp.logical_and(is_landing, jnp.logical_and(between_roads, can_bridge_gap)) + + final_player_x = jnp.where( + jnp.logical_or(should_snap, should_snap_bridge), + nearest_road_x, + new_player_x + ) # === UPDATE ROAD STATE === # Determine which road to assign on landing (priority: road A > road B > nearest) @@ -712,44 +790,23 @@ def _advance_car_core( @partial(jax.jit, static_argnums=(0,)) def _flag_step(self, state: UpNDownState, new_player_y: chex.Array, player_x: chex.Array, current_road: chex.Array) -> Tuple[Flag, chex.Array, chex.Array]: - """Update flag collection state and score. + """Update flag collection state and score (vectorized).""" + # Calculate flag X positions on both roads + # _get_x_on_road supports array inputs via advanced indexing + x_road_0 = self._get_x_on_road(state.flags.y, state.flags.road_segment, self.consts.FIRST_TRACK_CORNERS_X) + x_road_1 = self._get_x_on_road(state.flags.y, state.flags.road_segment, self.consts.SECOND_TRACK_CORNERS_X) - Args: - state: Current game state - new_player_y: Updated player Y position after movement - player_x: Current player X position - current_road: Current road player is on - - Returns: - Tuple of (updated_flags, score_delta, flags_collected_mask) - """ - # Check collision for each flag - def check_flag_collision(flag_idx): - flag_y = state.flags.y[flag_idx] - flag_road = state.flags.road[flag_idx] - flag_collected = state.flags.collected[flag_idx] - - # Calculate flag X position on its road - flag_segment = state.flags.road_segment[flag_idx] - flag_x = jax.lax.cond( - flag_road == 0, - lambda _: self._get_x_on_road(flag_y, flag_segment, self.consts.FIRST_TRACK_CORNERS_X), - lambda _: self._get_x_on_road(flag_y, flag_segment, self.consts.SECOND_TRACK_CORNERS_X), - operand=None, - ) - - # Check if player is close enough to collect the flag - y_distance = jnp.abs(new_player_y - flag_y) - x_distance = jnp.abs(player_x - flag_x) - same_road = (current_road == flag_road) - - collision = jnp.logical_and( - jnp.logical_and(y_distance < self.consts.COLLISION_THRESHOLD, x_distance < self.consts.COLLISION_THRESHOLD), - jnp.logical_and(same_road, ~flag_collected) - ) - return collision + flag_x = jnp.where(state.flags.road == 0, x_road_0, x_road_1) - new_collections = jax.vmap(check_flag_collision)(jnp.arange(self.consts.NUM_FLAGS)) + # Vectorized distance check + y_dist = jnp.abs(new_player_y - state.flags.y) + x_dist = jnp.abs(player_x - flag_x) + same_road = (current_road == state.flags.road) + + new_collections = jnp.logical_and( + jnp.logical_and(y_dist < self.consts.COLLISION_THRESHOLD, x_dist < self.consts.COLLISION_THRESHOLD), + jnp.logical_and(same_road, ~state.flags.collected) + ) # Update flags collected state new_flags_collected = jnp.logical_or(state.flags.collected, new_collections) @@ -929,7 +986,7 @@ def _death_step(self, state: UpNDownState) -> UpNDownState: def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: up = jnp.logical_or(action == Action.UP, action == Action.UPFIRE) down = jnp.logical_or(action == Action.DOWN, action == Action.DOWNFIRE) - jump = jnp.logical_or(action == Action.FIRE, jnp.logical_or(action == Action.UPFIRE, action == Action.DOWNFIRE)) + jump_pressed = jnp.logical_or(action == Action.FIRE, jnp.logical_or(action == Action.UPFIRE, action == Action.DOWNFIRE)) # Check if on a steep road section FIRST (before applying speed changes) is_on_steep_road = self._is_steep_road_segment( @@ -952,19 +1009,34 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: # Start with current speed player_speed = state.player_car.speed - # === STEEP ROAD BLOCKING LOGIC === + # === FRICTION & MOMENTUM LOGIC === + is_accelerating = up + is_braking = down + + # No friction - speed stays constant when no input + # Speed changes gradually (periodically, not every frame) + should_change_speed = (state.step_counter % self.consts.ACCELERATION_INTERVAL) == 0 + + # === ACCELERATION (UP) === # On steep road: UP action has NO effect (can't accelerate while on steep section) - # Apply UP acceleration only if NOT on steep road (or if jumping over it) can_accelerate = jnp.logical_not(on_steep_not_jumping) + player_speed = jnp.where( - jnp.logical_and(jnp.logical_and(player_speed < self.consts.MAX_SPEED, up), can_accelerate), + jnp.logical_and( + jnp.logical_and(should_change_speed, is_accelerating), + jnp.logical_and(player_speed < self.consts.MAX_SPEED, can_accelerate) + ), player_speed + 1, player_speed, ) + # === BRAKING (DOWN) === # DOWN action always works (can brake/reverse) player_speed = jnp.where( - jnp.logical_and(player_speed > -self.consts.MAX_SPEED, down), + jnp.logical_and( + jnp.logical_and(should_change_speed, is_braking), + player_speed > -self.consts.MAX_SPEED + ), player_speed - 1, player_speed, ) @@ -983,9 +1055,13 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: # Check if player has reached halfway point (50% progress through segment) past_halfway = steep_progress >= 0.5 + # Check if player has enough momentum to climb steep road + MIN_CLIMB_SPEED = 5 + has_momentum = player_speed >= MIN_CLIMB_SPEED + # Two behaviors based on progress: # 1. Before halfway: gradually reduce speed using timer - # 2. At/past halfway: immediately set speed to -2 (slide back) + # 2. At/past halfway: immediately slide back UNLESS we have enough momentum # Before halfway: reduce speed periodically using timer should_reduce_speed = jnp.logical_and( @@ -1007,18 +1083,25 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: steep_road_timer, ) - # At/past halfway: force speed to -2 (slide back down) - should_slide_back = jnp.logical_and(on_steep_going_up, past_halfway) + # At/past halfway: force speed to -2 (slide back down) IF momentum is lost + should_slide_back = jnp.logical_and( + on_steep_going_up, + jnp.logical_and(past_halfway, jnp.logical_not(has_momentum)) + ) player_speed = jnp.where( should_slide_back, jnp.int32(-3), player_speed, ) - can_start_jump = jnp.logical_and(state.jump_cooldown == 0, state.post_jump_cooldown == 0) + # === JUMP LOGIC === + can_start_jump = jnp.logical_and( + state.jump_cooldown == 0, + jnp.logical_and(state.post_jump_cooldown == 0, state.jump_key_released) + ) is_jumping = jnp.logical_or( jnp.logical_and(state.is_jumping, state.jump_cooldown > 0), - jnp.logical_and(state.is_on_road,jnp.logical_and(can_start_jump, jump)), + jnp.logical_and(state.is_on_road, jnp.logical_and(player_speed >= 0, jnp.logical_and(can_start_jump, jump_pressed))), ) # Detect when a new jump is starting (was not jumping, now is jumping) @@ -1067,11 +1150,18 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: # Lock slope at jump start, keep previous slope during jump (use jnp.where) jump_slope = jnp.where(starting_jump, new_jump_slope, state.jump_slope) + # Calculate dynamic jump duration based on speed + # Faster speed = shorter jump duration (covering gap faster) + # Increased base duration for more "air time" as requested + # Formula: 48 - 2 * abs(speed) -> Speed 8 = 32 frames (was 24 before) + current_jump_duration = 48 - 2 * jnp.abs(player_speed) + jump_duration = jnp.where(starting_jump, current_jump_duration.astype(jnp.int32), state.jump_total_duration) + # Use jnp.where for branchless execution of jump_cooldown jump_cooldown = jnp.where( state.jump_cooldown > 0, state.jump_cooldown - 1, - jnp.where(is_jumping, self.consts.JUMP_FRAMES, 0), + jnp.where(is_jumping, jump_duration, 0), ) # Use jnp.where for branchless execution of post_jump_cooldown @@ -1084,6 +1174,13 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: is_on_road = ~is_jumping is_landing = is_landing_now + # Calculate jump progress for magnetism + # Progress = (Total - Remaining) / Total + # Use jnp.maximum(..., 1.0) to avoid division by zero + safe_total_duration = jnp.maximum(state.jump_total_duration, 1.0) + jump_progress = (safe_total_duration - jump_cooldown.astype(jnp.float32)) / safe_total_duration + jump_progress = jnp.clip(jump_progress, 0.0, 1.0) + updated_player_car = self._advance_player_car( position_x=state.player_car.position.x, position_y=state.player_car.position.y, @@ -1098,12 +1195,16 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: car_type=state.player_car.type, is_landing=is_landing, stored_jump_slope=jump_slope, + jump_progress=jump_progress, ) # Check if a speed-changing action (UP or DOWN) was taken speed_action_taken = jnp.logical_or(up, down) # Round starts only after a speed-changing action round_started_now = jnp.logical_or(state.round_started, speed_action_taken) + + # Track jump key release for preventing held-key jumps + next_jump_key_released = jnp.logical_not(jump_pressed) next_state = state._replace( jump_cooldown=jump_cooldown, @@ -1116,6 +1217,8 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: movement_steps=jnp.where(round_started_now, state.movement_steps + 1, state.movement_steps), steep_road_timer=steep_road_timer, jump_slope=jump_slope, + jump_key_released=next_jump_key_released, + jump_total_duration=jump_duration, ) water_crash = jnp.logical_and(is_landing, updated_player_car.current_road == 2) @@ -1158,12 +1261,34 @@ def _flag_step_main(self, state: UpNDownState) -> UpNDownState: ) @partial(jax.jit, static_argnums=(0,)) - def _completion_bonus_step(self, state: UpNDownState) -> UpNDownState: - """Award bonus when all flags are collected.""" + def _level_progression_step(self, state: UpNDownState) -> UpNDownState: + """Handle level completion: award bonus and reset flags.""" all_flags_collected = jnp.all(state.flags_collected_mask) - # Use jnp.where for branchless execution + bonus = jnp.where(all_flags_collected, self.consts.ALL_FLAGS_BONUS, 0) - return state._replace(score=state.score + bonus) + + # Reset flags if all collected + new_collected = jnp.where(all_flags_collected, jnp.zeros_like(state.flags.collected), state.flags.collected) + new_mask = jnp.where(all_flags_collected, jnp.zeros_like(state.flags_collected_mask), state.flags_collected_mask) + + updated_flags = state.flags._replace(collected=new_collected) + + return state._replace( + score=state.score + bonus, + flags=updated_flags, + flags_collected_mask=new_mask + ) + + @partial(jax.jit, static_argnums=(0,)) + def _extra_life_step(self, state: UpNDownState) -> UpNDownState: + """Award extra life every 10000 points.""" + next_milestone = state.last_extra_life_score + self.consts.EXTRA_LIFE_THRESHOLD + should_award = state.score >= next_milestone + + new_lives = jnp.where(should_award, state.lives + 1, state.lives) + new_last_score = jnp.where(should_award, next_milestone, state.last_extra_life_score) + + return state._replace(lives=new_lives, last_extra_life_score=new_last_score) @partial(jax.jit, static_argnums=(0,)) def _collectible_step_main(self, state: UpNDownState) -> UpNDownState: @@ -1473,6 +1598,9 @@ def _respawn_after_collision(self, state: UpNDownState, new_lives: chex.Array) - awaiting_respawn=jnp.array(False), awaiting_round_start=jnp.array(True), # Wait for input to start round after respawn input_released=jnp.array(False), # Require button release before round can start + jump_key_released=jnp.array(True), + last_extra_life_score=state.last_extra_life_score, + jump_total_duration=jnp.array(self.consts.JUMP_FRAMES, dtype=jnp.int32), rng_key=rng_key, ) @@ -1496,9 +1624,11 @@ def _enemy_collision_step_main(self, state: UpNDownState) -> UpNDownState: # For ground collision: only trigger when enemy position is within tight distance overlap_x_ground = dx <= self.consts.GROUND_COLLISION_DISTANCE overlap_y_ground = wrapped_dy <= self.consts.GROUND_COLLISION_DISTANCE - # For late jump collision: use larger overlap based on car dimensions - overlap_x_jump = dx <= (state.player_car.position.width + state.enemy_cars.position.width) / 2.0 - overlap_y_jump = wrapped_dy <= (state.player_car.position.height + state.enemy_cars.position.height) / 2.0 + # For late jump collision: use larger overlap based on car dimensions plus extra tolerance + # "slightly more forgiving" + jump_tolerance = 4.0 + overlap_x_jump = dx <= (state.player_car.position.width + state.enemy_cars.position.width) / 2.0 + jump_tolerance + overlap_y_jump = wrapped_dy <= (state.player_car.position.height + state.enemy_cars.position.height) / 2.0 + jump_tolerance same_road = state.enemy_cars.current_road == state.player_car.current_road # Ground collision mask uses tight 3-pixel distance and same road @@ -1650,6 +1780,9 @@ def _reset_jit(self, key: chex.PRNGKey) -> Tuple[UpNDownObservation, UpNDownStat awaiting_respawn=jnp.array(False), awaiting_round_start=jnp.array(True), # Start frozen until first input input_released=jnp.array(True), # Can start immediately at game start + jump_key_released=jnp.array(True), + last_extra_life_score=jnp.array(0, dtype=jnp.int32), + jump_total_duration=jnp.array(self.consts.JUMP_FRAMES, dtype=jnp.int32), ) initial_obs = self._get_observation(state) return initial_obs, state @@ -1702,7 +1835,8 @@ def run_game_logic(s): s = self._death_step(s) s = self._passive_score_step_main(s) s = self._flag_step_main(s) - s = self._completion_bonus_step(s) + s = self._level_progression_step(s) + s = self._extra_life_step(s) s = self._collectible_step_main(s) s = self._enemy_step_main(s) s = self._enemy_collision_step_main(s) @@ -2131,9 +2265,9 @@ def _compute_flag_palette_ids(self) -> jnp.ndarray: return jnp.array([self._find_palette_id(color) for color in self.consts.FLAG_COLORS], dtype=jnp.int32) @partial(jax.jit, static_argnums=(0,)) - def _jump_arc_offset(self, jump_cooldown: chex.Array) -> chex.Array: + def _jump_arc_offset(self, jump_cooldown: chex.Array, total_duration: chex.Array) -> chex.Array: """Return a simple parabolic jump height based on remaining jump frames.""" - total = jnp.array(self.consts.JUMP_FRAMES, dtype=jnp.float32) + total = total_duration.astype(jnp.float32) remaining = jnp.array(jump_cooldown, dtype=jnp.float32) progress = jnp.clip((total - remaining) / jnp.maximum(total, 1.0), 0.0, 1.0) centered = (progress - 0.5) * 2.0 @@ -2195,13 +2329,20 @@ def select_enemy_mask(enemy_type: chex.Array, going_left: chex.Array): right_mask = self.enemy_right_masks[enemy_type] return jnp.where(going_left, left_mask, right_mask) + # Pre-cast enemy properties to optimal types for rendering BEFORE the scan loop + enemy_active_arr = state.enemy_cars.active + enemy_x_arr = state.enemy_cars.position.x.astype(jnp.int32) + enemy_y_arr = state.enemy_cars.position.y + enemy_type_arr = state.enemy_cars.type + enemy_direction_x_arr = state.enemy_cars.direction_x + def render_enemy(carry, enemy_idx): raster = carry - enemy_active = state.enemy_cars.active[enemy_idx] - enemy_x = state.enemy_cars.position.x[enemy_idx] - enemy_y = state.enemy_cars.position.y[enemy_idx] - enemy_type = state.enemy_cars.type[enemy_idx] - direction_x = state.enemy_cars.direction_x[enemy_idx] + enemy_active = enemy_active_arr[enemy_idx] + enemy_x = enemy_x_arr[enemy_idx] + enemy_y = enemy_y_arr[enemy_idx] + enemy_type = enemy_type_arr[enemy_idx] + direction_x = enemy_direction_x_arr[enemy_idx] screen_y = 105 + (enemy_y - state.player_car.position.y) # Hide enemies when awaiting round start or awaiting respawn should_hide = jnp.logical_or(state.awaiting_round_start, state.awaiting_respawn) @@ -2213,7 +2354,7 @@ def render_enemy(carry, enemy_idx): raster = jax.lax.cond( is_visible, - lambda r: self.jr.render_at(r, enemy_x.astype(jnp.int32), screen_y.astype(jnp.int32), enemy_mask), + lambda r: self.jr.render_at(r, enemy_x, screen_y.astype(jnp.int32), enemy_mask), lambda r: r, operand=raster, ) @@ -2223,7 +2364,7 @@ def render_enemy(carry, enemy_idx): jump_offset = jax.lax.cond( state.is_jumping, - lambda _: self._jump_arc_offset(state.jump_cooldown), + lambda _: self._jump_arc_offset(state.jump_cooldown, state.jump_total_duration), lambda _: jnp.array(0.0, dtype=jnp.float32), operand=None, ) From 2916b045ba6c06dc2004aec4f829b55913552a40 Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Thu, 5 Mar 2026 14:54:53 +0100 Subject: [PATCH 70/76] add initial mod implementation --- src/jaxatari/games/jax_upndown.py | 73 ++++++++++++++-- .../games/mods/upndown_mod_plugins.py | 83 +++++++++++++++++++ 2 files changed, 147 insertions(+), 9 deletions(-) create mode 100644 src/jaxatari/games/mods/upndown_mod_plugins.py diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 5b903ab0b..773e2e1b2 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -333,6 +333,46 @@ def _apply_steep_road_penalty( return final_speed, final_timer, jump_boost + @partial(jax.jit, static_argnums=(0,)) + def _sample_enemy_spawn_road(self, rng_key: chex.PRNGKey) -> chex.Array: + """Sample road index for enemy spawns. + + Extracted as a modding hook; default behavior is unchanged. + """ + return jax.random.randint(rng_key, shape=(), minval=0, maxval=2).astype(jnp.int32) + + @partial(jax.jit, static_argnums=(0,)) + def _collectible_score_values(self, state: UpNDownState, collectible_type_ids: chex.Array) -> chex.Array: + """Return score values for collectible types. + + Extracted as a modding hook; default behavior is unchanged. + """ + return self.consts.COLLECTIBLE_SCORES[collectible_type_ids] + + @partial(jax.jit, static_argnums=(0,)) + def _on_level_completed(self, state: UpNDownState) -> UpNDownState: + """Optional callback invoked only when all flags are collected. + + Default is a no-op and preserves existing game behavior. + """ + return state + + @partial(jax.jit, static_argnums=(0,)) + def _jump_speed_allows_start(self, player_speed: chex.Array) -> chex.Array: + """Return whether jump start is allowed for the current speed. + + Extracted as a modding hook; default behavior is unchanged. + """ + return player_speed >= 0 + + @partial(jax.jit, static_argnums=(0,)) + def _adjust_enemy_spawn_timer(self, state: UpNDownState, spawn_timer: chex.Array) -> chex.Array: + """Optional hook to post-process enemy spawn timer. + + Extracted as a modding hook; default behavior is unchanged. + """ + return spawn_timer + @partial(jax.jit, static_argnums=(0,)) def _get_slope_and_intercept_from_indices(self, current_road: chex.Array, road_index_A: chex.Array, road_index_B: chex.Array) -> Tuple[chex.Array, chex.Array]: """Calculate slope and intercept for the current road segment.""" @@ -925,8 +965,8 @@ def check_collision(idx): # Deactivate collected items final_active = jnp.logical_and(active_after_despawn, ~collections) - # Update score - vectorized lookup without vmap overhead - scores = self.consts.COLLECTIBLE_SCORES[spawned_type_id] + # Update score - extracted into hook for easier modding + scores = self._collectible_score_values(state, spawned_type_id) score_delta = jnp.sum(jnp.where(collections, scores, 0)) # Create final collectibles state @@ -1101,7 +1141,13 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: ) is_jumping = jnp.logical_or( jnp.logical_and(state.is_jumping, state.jump_cooldown > 0), - jnp.logical_and(state.is_on_road, jnp.logical_and(player_speed >= 0, jnp.logical_and(can_start_jump, jump_pressed))), + jnp.logical_and( + state.is_on_road, + jnp.logical_and( + self._jump_speed_allows_start(player_speed), + jnp.logical_and(can_start_jump, jump_pressed), + ), + ), ) # Detect when a new jump is starting (was not jumping, now is jumping) @@ -1273,12 +1319,19 @@ def _level_progression_step(self, state: UpNDownState) -> UpNDownState: updated_flags = state.flags._replace(collected=new_collected) - return state._replace( + next_state = state._replace( score=state.score + bonus, flags=updated_flags, flags_collected_mask=new_mask ) + return jax.lax.cond( + all_flags_collected, + lambda s: self._on_level_completed(s), + lambda s: s, + next_state, + ) + @partial(jax.jit, static_argnums=(0,)) def _extra_life_step(self, state: UpNDownState) -> UpNDownState: """Award extra life every 10000 points.""" @@ -1438,7 +1491,7 @@ def _enemy_step_main(self, state: UpNDownState) -> UpNDownState: spawn_side = jax.random.choice(key_spawn_side, jnp.array([-1.0, 1.0])) raw_spawn_y = state.player_car.position.y + spawn_side * spawn_offset spawn_y = -(((raw_spawn_y) * -1) % self.consts.TRACK_LENGTH) - spawn_road = jax.random.randint(key_spawn_direction, shape=(), minval=0, maxval=2) + spawn_road = self._sample_enemy_spawn_road(key_spawn_direction) segment_spawn = self._get_road_segment(spawn_y) spawn_x = jnp.where( @@ -1533,6 +1586,8 @@ def _enemy_step_main(self, state: UpNDownState) -> UpNDownState: age=enemy_age, ) + spawn_timer = self._adjust_enemy_spawn_timer(state, spawn_timer) + return state._replace( enemy_cars=next_enemy_cars, enemy_spawn_timer=spawn_timer, @@ -1991,7 +2046,7 @@ def observation_space(self) -> spaces.Dict: - player_score: int (0-999999) - lives: int (0-5) - is_jumping: int (0 or 1) - - jump_cooldown: int (0-28) + - jump_cooldown: int (0-48) - is_on_steep_road: int (0 or 1) - round_started: int (0 or 1) """ @@ -2003,7 +2058,7 @@ def observation_space(self) -> spaces.Dict: "width": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), "height": spaces.Box(low=0, high=210, shape=(), dtype=jnp.int32), }), - "speed": spaces.Box(low=-6, high=6, shape=(), dtype=jnp.int32), + "speed": spaces.Box(low=-self.consts.MAX_SPEED, high=self.consts.MAX_SPEED, shape=(), dtype=jnp.int32), "type": spaces.Box(low=0, high=3, shape=(), dtype=jnp.int32), "current_road": spaces.Box(low=0, high=2, shape=(), dtype=jnp.int32), "road_index_A": spaces.Box(low=0, high=30, shape=(), dtype=jnp.int32), @@ -2017,7 +2072,7 @@ def observation_space(self) -> spaces.Dict: "width": spaces.Box(low=0, high=160, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), "height": spaces.Box(low=0, high=210, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), }), - "speed": spaces.Box(low=-6, high=6, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), + "speed": spaces.Box(low=-(self.consts.ENEMY_SPEED_MAX + 1), high=(self.consts.ENEMY_SPEED_MAX + 1), shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), "type": spaces.Box(low=0, high=3, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), "current_road": spaces.Box(low=0, high=2, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), "road_index_A": spaces.Box(low=0, high=30, shape=(self.consts.MAX_ENEMY_CARS,), dtype=jnp.int32), @@ -2045,7 +2100,7 @@ def observation_space(self) -> spaces.Dict: "player_score": spaces.Box(low=0, high=999999, shape=(), dtype=jnp.int32), "lives": spaces.Box(low=0, high=5, shape=(), dtype=jnp.int32), "is_jumping": spaces.Box(low=0, high=1, shape=(), dtype=jnp.int32), - "jump_cooldown": spaces.Box(low=0, high=28, shape=(), dtype=jnp.int32), + "jump_cooldown": spaces.Box(low=0, high=48, shape=(), dtype=jnp.int32), "is_on_steep_road": spaces.Box(low=0, high=1, shape=(), dtype=jnp.int32), "round_started": spaces.Box(low=0, high=1, shape=(), dtype=jnp.int32), }) diff --git a/src/jaxatari/games/mods/upndown_mod_plugins.py b/src/jaxatari/games/mods/upndown_mod_plugins.py new file mode 100644 index 000000000..5382e0efa --- /dev/null +++ b/src/jaxatari/games/mods/upndown_mod_plugins.py @@ -0,0 +1,83 @@ +from functools import partial +import chex +import jax +import jax.numpy as jnp + +from jaxatari.games.jax_upndown import UpNDownState +from jaxatari.modification import JaxAtariInternalModPlugin + + +class RemoveStepRoadsMod(JaxAtariInternalModPlugin): + @partial(jax.jit, static_argnums=(0,)) + def _is_steep_road_segment(self, current_road: chex.Array, road_index_A: chex.Array, road_index_B: chex.Array) -> chex.Array: + return jnp.array(False) + + +class HigherPlayerSpeedMod(JaxAtariInternalModPlugin): + constants_overrides = { + "MAX_SPEED": 9, + } + + +class MoreCollectiblesMod(JaxAtariInternalModPlugin): + constants_overrides = { + "MAX_COLLECTIBLES": 4, + "COLLECTIBLE_SPAWN_INTERVAL": 120, + } + + +class MinCarSpawnGapMod(JaxAtariInternalModPlugin): + conflicts_with = ["progressive_car_spawn_rate"] + constants_overrides = { + "ENEMY_SPAWN_INTERVAL_BASE": 50, + } + + +class AllowJumpBackwardsMod(JaxAtariInternalModPlugin): + @partial(jax.jit, static_argnums=(0,)) + def _jump_speed_allows_start(self, player_speed: chex.Array) -> chex.Array: + return jnp.array(True) + + +class SingleLaneCarSpawnMod(JaxAtariInternalModPlugin): + @partial(jax.jit, static_argnums=(0,)) + def _sample_enemy_spawn_road(self, rng_key: chex.PRNGKey) -> chex.Array: + return jnp.array(1, dtype=jnp.int32) + + +class ProgressiveCarSpawnRateMod(JaxAtariInternalModPlugin): + conflicts_with = ["minimum_car_spawn_gap"] + + @partial(jax.jit, static_argnums=(0,)) + def _adjust_enemy_spawn_timer(self, state: UpNDownState, spawn_timer: chex.Array) -> chex.Array: + start_interval = jnp.int32(self._env.consts.ENEMY_SPAWN_INTERVAL_BASE) + min_interval = jnp.int32(8) + horizon = jnp.float32(1800.0) + + progress = jnp.clip(state.movement_steps.astype(jnp.float32) / horizon, 0.0, 1.0) + decayed_interval = jnp.round( + start_interval.astype(jnp.float32) - progress * (start_interval.astype(jnp.float32) - min_interval.astype(jnp.float32)) + ).astype(jnp.int32) + + target_interval = jnp.maximum(min_interval, decayed_interval) + return jnp.minimum(spawn_timer, target_interval) + + @partial(jax.jit, static_argnums=(0,)) + def _on_level_completed(self, state: UpNDownState) -> UpNDownState: + return state._replace( + movement_steps=jnp.array(0, dtype=jnp.int32), + enemy_spawn_timer=jnp.array(self._env.consts.ENEMY_SPAWN_INTERVAL_BASE, dtype=jnp.int32), + ) + + +class TimeDecayCollectibleValueMod(JaxAtariInternalModPlugin): + @partial(jax.jit, static_argnums=(0,)) + def _on_level_completed(self, state: UpNDownState) -> UpNDownState: + return state._replace(movement_steps=jnp.array(0, dtype=jnp.int32)) + + @partial(jax.jit, static_argnums=(0,)) + def _collectible_score_values(self, state: UpNDownState, collectible_type_ids: chex.Array) -> chex.Array: + base_scores = self._env.consts.COLLECTIBLE_SCORES[collectible_type_ids] + elapsed_decay = jnp.floor(state.movement_steps.astype(jnp.float32) / 200.0).astype(jnp.int32) + min_scores = jnp.maximum(jnp.int32(10), base_scores // 3) + return jnp.maximum(base_scores - elapsed_decay, min_scores) From 68443a2ba2955775dfb6ebe4e6b038cbd51eee60 Mon Sep 17 00:00:00 2001 From: shaik05 Date: Fri, 6 Mar 2026 13:38:07 +0100 Subject: [PATCH 71/76] Allow backward jumping and remove steep road --- src/jaxatari/games/jax_upndown.py | 34 +++++++++++++++++++------------ 1 file changed, 21 insertions(+), 13 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 773e2e1b2..2af620f63 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -1026,18 +1026,26 @@ def _death_step(self, state: UpNDownState) -> UpNDownState: def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: up = jnp.logical_or(action == Action.UP, action == Action.UPFIRE) down = jnp.logical_or(action == Action.DOWN, action == Action.DOWNFIRE) - jump_pressed = jnp.logical_or(action == Action.FIRE, jnp.logical_or(action == Action.UPFIRE, action == Action.DOWNFIRE)) - - # Check if on a steep road section FIRST (before applying speed changes) - is_on_steep_road = self._is_steep_road_segment( - state.player_car.current_road, - state.player_car.road_index_A, - state.player_car.road_index_B, + jump = jnp.logical_or(action == Action.FIRE, jnp.logical_or(action == Action.UPFIRE, action == Action.DOWNFIRE)) + player_speed = state.player_car.speed.astype(jnp.int32) + + player_speed = jax.lax.cond( + jnp.logical_and(state.player_car.speed < self.consts.MAX_SPEED, up), + lambda s: s + 1, + lambda s: s, + operand=player_speed, ) - - # Calculate progress through steep segment (0.0 = bottom, 1.0 = top) - steep_progress = self._get_steep_segment_progress( - state.player_car.position.y, + + player_speed = jax.lax.cond( + jnp.logical_and(state.player_car.speed > -self.consts.MAX_SPEED, down), + lambda s: s - 1, + lambda s: s, + operand=player_speed, + ) + + # Check if on a steep road section (no X direction change) and apply speed reduction + # This simulates steep road sections that require a jump to pass when going upward + is_on_steep_road = self._is_steep_road_segment( state.player_car.current_road, state.player_car.road_index_A, state.player_car.road_index_B, @@ -1063,7 +1071,7 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: player_speed = jnp.where( jnp.logical_and( - jnp.logical_and(should_change_speed, is_accelerating), + jnp.logical_and(up,True), jnp.logical_and(player_speed < self.consts.MAX_SPEED, can_accelerate) ), player_speed + 1, @@ -1620,7 +1628,7 @@ def _respawn_after_collision(self, state: UpNDownState, new_lives: chex.Array) - width=self.consts.PLAYER_SIZE[0], height=self.consts.PLAYER_SIZE[1], ), - speed=jnp.array(0.0, dtype=jnp.float32), + speed=jnp.array(0, dtype=jnp.int32), direction_x=jnp.array(0, dtype=jnp.int32), current_road=respawn_road, road_index_A=start_segment, From 3e546048a362d0778c21bf83cd39b22d522d03be Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sat, 21 Mar 2026 20:49:28 +0100 Subject: [PATCH 72/76] go back to last running version --- src/jaxatari/games/jax_upndown.py | 118 ++++-------------------------- 1 file changed, 15 insertions(+), 103 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 2af620f63..4f7a3af1f 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -50,11 +50,7 @@ class UpNDownConstants(NamedTuple): LANDING_COLLISION_DISTANCE: float = 12.0 # Larger collision distance when landing (increased for easier enemy kills) GROUND_COLLISION_DISTANCE: float = 3.0 # Tight collision distance for ground collisions LATE_JUMP_ENEMY_SCORE: int = 400 - STEEP_ROAD_SPEED_REDUCTION_INTERVAL: int = 12 # Frames between each speed reduction on steep roads - STEEP_ROAD_MIN_SPEED: float = -2.0 # Minimum speed on steep roads - STEEP_ROAD_JUMP_BOOST: float = 1.5 # Multiplier for jump height on steep roads - STEEP_ROAD_RECOVERY_BOOST: float = 0.8 # Speed boost after leaving steep road - STEEP_ROAD_COOLDOWN: int = 5 + STEEP_ROAD_SPEED_REDUCTION_INTERVAL: int = 8 # Frames between each speed reduction on steep roads PASSIVE_SCORE_INTERVAL: int = 60 # Steps between passive score awards PASSIVE_SCORE_AMOUNT: int = 10 # Points awarded for passive scoring COLLISION_THRESHOLD: float = 5.0 # Distance threshold for flag/collectible collision @@ -256,8 +252,8 @@ def _compute_movement_timing(self, speed: chex.Array, step_counter: chex.Array) Returns: Tuple of (move_y, move_x, step_size, speed_sign) """ - abs_speed = jnp.abs(speed).astype(jnp.int32) - speed_index = jnp.minimum(abs_speed, self._speed_dividers.shape[0] - 1).astype(jnp.int32) + abs_speed = jnp.abs(speed) + speed_index = jnp.minimum(abs_speed, jnp.int32(self._speed_dividers.shape[0] - 1)) speed_divider = self._speed_dividers[speed_index] effective_divider = jnp.maximum(1, speed_divider) period = jnp.maximum(1, 16 // effective_divider) @@ -269,69 +265,6 @@ def _compute_movement_timing(self, speed: chex.Array, step_counter: chex.Array) step_size = jnp.where(speed_index >= 6, 1.5 + (speed_index - 6) * 0.2, 1.0) return move_y, move_x, step_size, speed_sign - def _apply_steep_road_penalty( - self, - speed: chex.Array, - is_on_steep_road: chex.Array, - steep_road_timer: chex.Array, - is_jumping: chex.Array, - jump_cooldown: chex.Array, - ) -> Tuple[chex.Array, chex.Array, chex.Array]: - """ - Apply enhanced steep road penalty with perfect balance and edge case handling. - - - Dynamically reduces speed on steep roads when going upward. - - Provides jump boost and recovery for better flow. - - Includes cooldown to prevent rapid reductions. - - Returns: (new_speed, new_timer, jump_boost_multiplier) - """ - going_up = speed > 0 - on_steep_going_up = jnp.logical_and(is_on_steep_road, going_up) - in_cooldown = steep_road_timer < 0 # Negative timer indicates cooldown - - # Increment timer only if not in cooldown and on steep road going up - timer_increment = jax.lax.cond( - jnp.logical_and(on_steep_going_up, jnp.logical_not(in_cooldown)), - lambda _: 1, - lambda _: 0, - operand=None, - ) - new_timer = steep_road_timer + timer_increment - - # Apply reduction when timer reaches interval and not in cooldown - should_reduce = jnp.logical_and( - on_steep_going_up, - jnp.logical_and(new_timer >= self.consts.STEEP_ROAD_SPEED_REDUCTION_INTERVAL, jnp.logical_not(in_cooldown)) - ) - - # Proportional reduction: stronger for higher speeds, with minimum cap - reduction_factor = jnp.maximum(0.05, speed * 0.15) # 5-15% of speed - reduced_speed = jnp.maximum(speed - reduction_factor, self.consts.STEEP_ROAD_MIN_SPEED) - - # Set cooldown after reduction (negative timer) - final_timer = jax.lax.cond( - should_reduce, - lambda _: -self.consts.STEEP_ROAD_COOLDOWN, - lambda _: new_timer, - operand=None, - ) - - # Recovery boost after leaving steep road (not jumping) - just_left_steep = jnp.logical_and(jnp.logical_not(on_steep_going_up), jnp.logical_not(is_jumping)) - recovery_boost = jax.lax.cond(just_left_steep, lambda _: self.consts.STEEP_ROAD_RECOVERY_BOOST, lambda _: 0.0, operand=None) - - # Jump boost if jumping on steep road - jump_boost = jax.lax.cond( - jnp.logical_and(on_steep_going_up, jump_cooldown > 0), - lambda _: self.consts.STEEP_ROAD_JUMP_BOOST, - lambda _: 1.0, - operand=None, - ) - - final_speed = jax.lax.cond(should_reduce, lambda _: reduced_speed + recovery_boost, lambda _: speed + recovery_boost, operand=None) - - return final_speed, final_timer, jump_boost @partial(jax.jit, static_argnums=(0,)) def _sample_enemy_spawn_road(self, rng_key: chex.PRNGKey) -> chex.Array: @@ -1026,31 +959,23 @@ def _death_step(self, state: UpNDownState) -> UpNDownState: def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: up = jnp.logical_or(action == Action.UP, action == Action.UPFIRE) down = jnp.logical_or(action == Action.DOWN, action == Action.DOWNFIRE) - jump = jnp.logical_or(action == Action.FIRE, jnp.logical_or(action == Action.UPFIRE, action == Action.DOWNFIRE)) - player_speed = state.player_car.speed.astype(jnp.int32) - - player_speed = jax.lax.cond( - jnp.logical_and(state.player_car.speed < self.consts.MAX_SPEED, up), - lambda s: s + 1, - lambda s: s, - operand=player_speed, - ) - - player_speed = jax.lax.cond( - jnp.logical_and(state.player_car.speed > -self.consts.MAX_SPEED, down), - lambda s: s - 1, - lambda s: s, - operand=player_speed, - ) - - # Check if on a steep road section (no X direction change) and apply speed reduction - # This simulates steep road sections that require a jump to pass when going upward + jump_pressed = jnp.logical_or(action == Action.FIRE, jnp.logical_or(action == Action.UPFIRE, action == Action.DOWNFIRE)) + + # Check if on a steep road section FIRST (before applying speed changes) is_on_steep_road = self._is_steep_road_segment( state.player_car.current_road, state.player_car.road_index_A, state.player_car.road_index_B, ) + # Calculate progress through steep segment (0.0 = bottom, 1.0 = top) + steep_progress = self._get_steep_segment_progress( + state.player_car.position.y, + state.player_car.current_road, + state.player_car.road_index_A, + state.player_car.road_index_B, + ) + # Determine if player is on steep road going up (not jumping) on_steep_not_jumping = jnp.logical_and(is_on_steep_road, jnp.logical_not(state.is_jumping)) @@ -1071,7 +996,7 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: player_speed = jnp.where( jnp.logical_and( - jnp.logical_and(up,True), + jnp.logical_and(should_change_speed, is_accelerating), jnp.logical_and(player_speed < self.consts.MAX_SPEED, can_accelerate) ), player_speed + 1, @@ -2160,19 +2085,6 @@ def __init__(self, consts: UpNDownConstants = None): channels=3, #downscale=(84, 84) ) - def _createBackgroundSprite(self, dimensions: Tuple[int, int]) -> jnp.ndarray: - - height, width = dimensions - # Create a vertical gradient: blue at top, lighter blue at bottom - top_color = jnp.array([135, 206, 235, 255], dtype=jnp.uint8) # Sky blue - bottom_color = jnp.array([173, 216, 230, 255], dtype=jnp.uint8) # Lighter sky blue - - # Linear interpolation for gradient - y_coords = jnp.arange(height, dtype=jnp.float32) / (height - 1) - gradient = jnp.outer(y_coords, bottom_color - top_color) + top_color - gradient = jnp.clip(gradient, 0, 255).astype(jnp.uint8) - - return gradient self.jr = render_utils.JaxRenderingUtils(self.config) background = self._createBackgroundSprite(self.config.game_dimensions) From 18cbb4dc62c5c09532bd4a304c47c19ff2059fbd Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Sun, 22 Mar 2026 20:42:00 +0100 Subject: [PATCH 73/76] fix some failing pipeline tests --- src/jaxatari/games/jax_upndown.py | 61 ++++++++++++++++++++++--------- 1 file changed, 44 insertions(+), 17 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 4f7a3af1f..37d5c0b1c 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -6,15 +6,16 @@ import jax.lax import jax.numpy as jnp import chex +from flax import struct import jaxatari.spaces as spaces from jaxatari.renderers import JAXGameRenderer from jaxatari.rendering import jax_rendering_utils as render_utils from jaxatari.environment import JaxEnvironment, JAXAtariAction as Action -class UpNDownConstants(NamedTuple): +class UpNDownConstants(struct.PyTreeNode): FRAME_SKIP: int = 4 - DIFFICULTIES: chex.Array = jnp.array([0, 1, 2, 3, 4, 5]) + DIFFICULTIES: chex.Array = struct.field(default_factory=lambda: jnp.array([0, 1, 2, 3, 4, 5])) MAX_SPEED: int = 7 INITIAL_LIVES: int = 5 JUMP_ARC_HEIGHT: float = 22.0 @@ -57,15 +58,15 @@ class UpNDownConstants(NamedTuple): ACCELERATION_INTERVAL: int = 6 # Frames between speed changes when holding up/down EXTRA_LIFE_THRESHOLD: int = 10000 # Score threshold for extra life TRACK_LENGTH: int = 1036 - FIRST_TRACK_CORNERS_X: chex.Array = jnp.array([30, 75, 128, 75, 21, 75, 131, 111, 150, 95, 150, 115, 150, 108, 150, 115, 115, 75, 18, 38, 67, 38, 38, 20, 64, 30]) - TRACK_CORNERS_Y: chex.Array = jnp.array([0, -40, -98, -155, -203, -268, -327, -347, -382, -467, -525, -565, -597, -625, -670, -705, -738, -788, -838, -862, -898, -925, -950, -972, -1000, -1035]) - SECOND_TRACK_CORNERS_X: chex.Array = jnp.array([115, 75, 20, 75, 133, 75, 22, 37, 63, 27, 66, 30, 63, 24, 60, 38, 38, 75, 131, 111, 150, 118, 118, 98, 150, 115]) + FIRST_TRACK_CORNERS_X: chex.Array = struct.field(default_factory=lambda: jnp.array([30, 75, 128, 75, 21, 75, 131, 111, 150, 95, 150, 115, 150, 108, 150, 115, 115, 75, 18, 38, 67, 38, 38, 20, 64, 30])) + TRACK_CORNERS_Y: chex.Array = struct.field(default_factory=lambda: jnp.array([0, -40, -98, -155, -203, -268, -327, -347, -382, -467, -525, -565, -597, -625, -670, -705, -738, -788, -838, -862, -898, -925, -950, -972, -1000, -1035])) + SECOND_TRACK_CORNERS_X: chex.Array = struct.field(default_factory=lambda: jnp.array([115, 75, 20, 75, 133, 75, 22, 37, 63, 27, 66, 30, 63, 24, 60, 38, 38, 75, 131, 111, 150, 118, 118, 98, 150, 115])) PLAYER_SIZE: Tuple[int, int] = (4, 16) INITIAL_ROAD_POS_Y: int = 25 # Flag constants - 8 flags with different colors matching the top row NUM_FLAGS: int = 8 # Flag colors as RGBA values (matching the top row from left to right) - FLAG_COLORS: chex.Array = jnp.array([ + FLAG_COLORS: chex.Array = struct.field(default_factory=lambda: jnp.array([ [184, 50, 50, 255], # Red [181, 83, 40, 255], # Orange [162, 98, 33, 255], # Dark orange @@ -74,14 +75,14 @@ class UpNDownConstants(NamedTuple): [168, 48, 143, 255], # Magenta [125, 48, 173, 255], # Purple [78, 50, 181, 255], # Blue - ]) + ])) # Top display positions for each flag (x coordinates where blackout squares appear) - FLAG_TOP_X_POSITIONS: chex.Array = jnp.array([13, 30, 47, 64, 82, 98, 118, 134]) + FLAG_TOP_X_POSITIONS: chex.Array = struct.field(default_factory=lambda: jnp.array([13, 30, 47, 64, 82, 98, 118, 134])) FLAG_TOP_Y: int = 20 FLAG_BLACKOUT_SIZE: Tuple[int, int] = (14, 14) # Size of blackout square FLAG_COLLECTION_SCORE: int = 75 # Points awarded for collecting a flag # Life display constants - positions of life cars at the bottom - LIFE_BOTTOM_X_POSITIONS: chex.Array = jnp.array([13, 18, 25, 33, 33]) # X positions for 5 life cars + LIFE_BOTTOM_X_POSITIONS: chex.Array = struct.field(default_factory=lambda: jnp.array([13, 18, 25, 33, 33])) # X positions for 5 life cars LIFE_BOTTOM_Y: int = 195 # Collectible constants - unified dynamic spawning MAX_COLLECTIBLES: int = 1 # Maximum collectibles that can exist at once (pool of mixed types) @@ -93,11 +94,23 @@ class UpNDownConstants(NamedTuple): COLLECTIBLE_TYPE_LOLLYPOP: int = 2 COLLECTIBLE_TYPE_ICE_CREAM: int = 3 # Collectible type spawn probabilities (cumulative thresholds for random sampling) - COLLECTIBLE_SPAWN_PROBABILITIES: chex.Array = jnp.array([35, 65, 90, 100], dtype=jnp.int32) # Cherry: 35%, Balloon: 30%, Lollypop: 25%, IceCream: 10% + COLLECTIBLE_SPAWN_PROBABILITIES: chex.Array = struct.field(default_factory=lambda: jnp.array([35, 65, 90, 100], dtype=jnp.int32)) # Cherry: 35%, Balloon: 30%, Lollypop: 25%, IceCream: 10% # Collectible type scores - COLLECTIBLE_SCORES: chex.Array = jnp.array([50, 65, 70, 75], dtype=jnp.int32) # [cherry, balloon, lollypop, ice_cream] + COLLECTIBLE_SCORES: chex.Array = struct.field(default_factory=lambda: jnp.array([50, 65, 70, 75], dtype=jnp.int32)) # [cherry, balloon, lollypop, ice_cream] # Shared collectible colors - COLLECTIBLE_COLORS: chex.Array = FLAG_COLORS + COLLECTIBLE_COLORS: chex.Array = struct.field(default_factory=lambda: jnp.array([ + [184, 50, 50, 255], + [181, 83, 40, 255], + [162, 98, 33, 255], + [134, 134, 29, 255], + [200, 72, 72, 255], + [168, 48, 143, 255], + [125, 48, 173, 255], + [78, 50, 181, 255], + ])) + + def _replace(self, **kwargs): + return self.replace(**kwargs) @@ -151,7 +164,8 @@ class EnemyCars(NamedTuple): active: chex.Array age: chex.Array -class UpNDownState(NamedTuple): +@struct.dataclass +class UpNDownState: score: chex.Array difficulty: chex.Array jump_cooldown: chex.Array @@ -187,9 +201,13 @@ class UpNDownState(NamedTuple): last_extra_life_score: chex.Array # Score at which last extra life was awarded jump_total_duration: chex.Array # Total duration of the current/last jump for rendering arc + def _replace(self, **kwargs): + return self.replace(**kwargs) + -class UpNDownObservation(NamedTuple): +@struct.dataclass +class UpNDownObservation: player_car: Car enemy_cars: EnemyCars flags: Flag @@ -202,14 +220,23 @@ class UpNDownObservation(NamedTuple): is_on_steep_road: chex.Array round_started: chex.Array + def _replace(self, **kwargs): + return self.replace(**kwargs) + -class UpNDownInfo(NamedTuple): +@struct.dataclass +class UpNDownInfo: """Additional info for debugging and analysis.""" step_counter: jnp.ndarray # Total steps taken difficulty: jnp.ndarray # Current difficulty level movement_steps: jnp.ndarray # Steps since round started jump_slope: jnp.ndarray # Current jump trajectory slope player_road_segment: jnp.ndarray # Current road segment index + + def _replace(self, **kwargs): + return self.replace(**kwargs) + + class JaxUpNDown(JaxEnvironment[UpNDownState, UpNDownObservation, UpNDownInfo, UpNDownConstants]): def __init__(self, consts: UpNDownConstants = None, reward_funcs: list[callable]=None): consts = consts or UpNDownConstants() @@ -2077,10 +2104,10 @@ def _get_done(self, state: UpNDownState) -> bool: return state.lives <= 0 class UpNDownRenderer(JAXGameRenderer): - def __init__(self, consts: UpNDownConstants = None): + def __init__(self, consts: UpNDownConstants = None, config: render_utils.RendererConfig | None = None): super().__init__() self.consts = consts or UpNDownConstants() - self.config = render_utils.RendererConfig( + self.config = config or render_utils.RendererConfig( game_dimensions=(210, 160), channels=3, #downscale=(84, 84) From 89c68a3363696f31940fdafffb696f9b96b6602f Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Tue, 24 Mar 2026 00:47:24 +0100 Subject: [PATCH 74/76] add level 2 and 3 --- src/jaxatari/games/jax_upndown.py | 630 +++++++++++++----- .../games/mods/upndown/upndown_mod_plugins.py | 19 +- .../up_n_down/background/backround_lvl2_1.npy | Bin 0 -> 100996 bytes .../up_n_down/background/backround_lvl2_2.npy | Bin 0 -> 101664 bytes .../up_n_down/background/backround_lvl2_3.npy | Bin 0 -> 49376 bytes .../up_n_down/background/backround_lvl2_4.npy | Bin 0 -> 80992 bytes .../up_n_down/background/backround_lvl2_5.npy | Bin 0 -> 63968 bytes .../up_n_down/background/backround_lvl2_6.npy | Bin 0 -> 67616 bytes .../up_n_down/background/backround_lvl2_7.npy | Bin 0 -> 38784 bytes .../up_n_down/background/backround_lvl2_8.npy | Bin 0 -> 87680 bytes .../up_n_down/background/backround_lvl2_9.npy | Bin 0 -> 33568 bytes .../up_n_down/background/backround_lvl3_1.npy | Bin 0 -> 101056 bytes .../up_n_down/background/backround_lvl3_2.npy | Bin 0 -> 51808 bytes .../up_n_down/background/backround_lvl3_3.npy | Bin 0 -> 38432 bytes .../up_n_down/background/backround_lvl3_4.npy | Bin 0 -> 93152 bytes .../up_n_down/background/backround_lvl3_5.npy | Bin 0 -> 71264 bytes .../up_n_down/background/backround_lvl3_6.npy | Bin 0 -> 98624 bytes .../up_n_down/background/backround_lvl3_7.npy | Bin 0 -> 65792 bytes .../up_n_down/background/backround_lvl3_8.npy | Bin 0 -> 85856 bytes .../up_n_down/background/backround_lvl3_9.npy | Bin 0 -> 11072 bytes 20 files changed, 469 insertions(+), 180 deletions(-) create mode 100644 src/jaxatari/games/sprites/up_n_down/background/backround_lvl2_1.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/background/backround_lvl2_2.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/background/backround_lvl2_3.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/background/backround_lvl2_4.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/background/backround_lvl2_5.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/background/backround_lvl2_6.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/background/backround_lvl2_7.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/background/backround_lvl2_8.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/background/backround_lvl2_9.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/background/backround_lvl3_1.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/background/backround_lvl3_2.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/background/backround_lvl3_3.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/background/backround_lvl3_4.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/background/backround_lvl3_5.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/background/backround_lvl3_6.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/background/backround_lvl3_7.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/background/backround_lvl3_8.npy create mode 100644 src/jaxatari/games/sprites/up_n_down/background/backround_lvl3_9.npy diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 37d5c0b1c..63041d7e7 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -1,10 +1,12 @@ import os from functools import partial +import re from typing import NamedTuple, Tuple import jax import jax.lax import jax.numpy as jnp +import numpy as np import chex from flax import struct @@ -57,10 +59,27 @@ class UpNDownConstants(struct.PyTreeNode): COLLISION_THRESHOLD: float = 5.0 # Distance threshold for flag/collectible collision ACCELERATION_INTERVAL: int = 6 # Frames between speed changes when holding up/down EXTRA_LIFE_THRESHOLD: int = 10000 # Score threshold for extra life + LEVEL_COUNT: int = 3 TRACK_LENGTH: int = 1036 - FIRST_TRACK_CORNERS_X: chex.Array = struct.field(default_factory=lambda: jnp.array([30, 75, 128, 75, 21, 75, 131, 111, 150, 95, 150, 115, 150, 108, 150, 115, 115, 75, 18, 38, 67, 38, 38, 20, 64, 30])) - TRACK_CORNERS_Y: chex.Array = struct.field(default_factory=lambda: jnp.array([0, -40, -98, -155, -203, -268, -327, -347, -382, -467, -525, -565, -597, -625, -670, -705, -738, -788, -838, -862, -898, -925, -950, -972, -1000, -1035])) - SECOND_TRACK_CORNERS_X: chex.Array = struct.field(default_factory=lambda: jnp.array([115, 75, 20, 75, 133, 75, 22, 37, 63, 27, 66, 30, 63, 24, 60, 38, 38, 75, 131, 111, 150, 118, 118, 98, 150, 115])) + HAZARD_LEVEL_INDEX: int = 2 # Third level (0-based index) + HAZARD_ZONE_1_MIN_Y: float = 557.0 + HAZARD_ZONE_1_MAX_Y: float = 587.0 + HAZARD_ZONE_2_MIN_Y: float = 830.0 + HAZARD_ZONE_2_MAX_Y: float = 860.0 + FIRST_TRACK_CORNERS_X: chex.Array = struct.field(default_factory=lambda: jnp.array([ + [30, 75, 128, 75, 21, 75, 131, 111, 150, 95, 150, 115, 150, 108, 150, 115, 115, 115, 75, 18, 38, 67, 38, 38, 20, 64, 30], + [40, 38, 17, 17, 17, 40, 66, 22, 33, 66, 38, 72, 130, 118, 118, 145, 120, 72, 30, 22, 37, 22, 30, 73, 60, 50, 45], + [16, 35, 16, 28, 16, 71, 130, 95, 145, 100, 145, 100, 115, 75, 75, 102, 102, 75, 22, 22, 40, 22, 34, 65, 65, 35, 16], + ])) + TRACK_CORNERS_Y: chex.Array = struct.field(default_factory=lambda: jnp.array + ([[0, -40, -98, -155, -203, -268, -327, -347, -382, -467, -525, -565, -597, -625, -670, -705, -709, -738, -788, -838, -862, -898, -925, -950, -972, -1000, -1035], + [0, -5, -53, -57, -101, -179, -220, -246, -296, -330, -353, -377, -440, -457, -540, -580, -620, -684, -750, -770, -800, -880, -926, -1000, -1015, -1020, -1026], + [0, -26, -54, -72, -96, -150, -190, -210, -245, -278, -308, -335, -347, -390, -434, -454, -617, -635, -683, -714, -720, -735, -796, -824, -850, -951, -1028]])) + SECOND_TRACK_CORNERS_X: chex.Array = struct.field(default_factory=lambda: jnp.array([ + [115, 75, 20, 75, 133, 75, 22, 37, 63, 27, 66, 30, 63, 24, 60, 38, 38, 38, 75, 131, 111, 150, 118, 118, 98, 150, 115], + [105, 107, 130, 130, 130, 115, 145, 100, 112, 145, 110, 72, 22, 33, 33, 60, 34, 72, 120, 110, 125, 100, 115, 73, 80, 85, 90], + [128, 115, 128, 118, 130, 71, 22, 30, 65, 22, 65, 22, 30, 75, 75, 46, 46, 75, 130, 130, 110, 95, 110, 145, 145, 110, 130], + ])) PLAYER_SIZE: Tuple[int, int] = (4, 16) INITIAL_ROAD_POS_Y: int = 25 # Flag constants - 8 flags with different colors matching the top row @@ -179,6 +198,7 @@ class UpNDownState: step_counter: chex.Array rng_key: chex.PRNGKey round_started: chex.Array + level: chex.Array # Current level index (0, 1, 2) movement_steps: chex.Array steep_road_timer: chex.Array # Timer for steep road speed reduction jump_slope: chex.Array # X movement per Y step, locked at jump start (float) @@ -309,13 +329,91 @@ def _collectible_score_values(self, state: UpNDownState, collectible_type_ids: c """ return self.consts.COLLECTIBLE_SCORES[collectible_type_ids] + @partial(jax.jit, static_argnums=(0,)) + def _get_spawn_segment_for_level(self, level: chex.Array, road: chex.Array) -> chex.Array: + """Pick the first non-steep segment for the requested level/road.""" + corners_a, corners_b = self._get_track_corners_for_level(level) + dx = jnp.where( + road == 0, + corners_a[1:] - corners_a[:-1], + corners_b[1:] - corners_b[:-1], + ) + valid = jnp.abs(dx) >= 1 + first_valid = jnp.argmax(valid.astype(jnp.int32)) + return jnp.where(jnp.any(valid), first_valid.astype(jnp.int32), jnp.int32(0)) + + @partial(jax.jit, static_argnums=(0,)) + def _get_spawn_position_for_level(self, level: chex.Array, road: chex.Array) -> Tuple[chex.Array, chex.Array, chex.Array]: + """Return (segment, y, x) spawn tuple aligned to level geometry.""" + segment = self._get_spawn_segment_for_level(level, road) + corners_y = self._get_track_corners_y_for_level(level) + spawn_y = corners_y[segment].astype(jnp.float32) + corners_a, corners_b = self._get_track_corners_for_level(level) + spawn_x = jnp.where( + road == 0, + self._get_x_on_road(spawn_y, segment, corners_a, corners_y), + self._get_x_on_road(spawn_y, segment, corners_b, corners_y), + ) + return segment, spawn_y, spawn_x + @partial(jax.jit, static_argnums=(0,)) def _on_level_completed(self, state: UpNDownState) -> UpNDownState: - """Optional callback invoked only when all flags are collected. + """Advance to next level and freeze until release+press input starts it.""" + rng_key, enemy_key = jax.random.split(state.rng_key) + next_level = (state.level + jnp.int32(1)) % jnp.int32(self.consts.LEVEL_COUNT) + start_road = jnp.int32(0) + start_segment, player_start_y, start_x = self._get_spawn_position_for_level(next_level, start_road) - Default is a no-op and preserves existing game behavior. - """ - return state + enemy_cars = self._initialize_enemies(enemy_key, player_start_y, next_level) + collectibles = self._initialize_collectibles() + + player_car = state.player_car._replace( + position=EntityPosition( + x=jnp.asarray(start_x, dtype=jnp.float32), + y=jnp.asarray(player_start_y, dtype=jnp.float32), + width=state.player_car.position.width, + height=state.player_car.position.height, + ), + speed=jnp.array(0, dtype=jnp.int32), + direction_x=jnp.array(0, dtype=jnp.int32), + current_road=start_road, + road_index_A=start_segment, + road_index_B=start_segment, + ) + + return state._replace( + level=next_level, + is_dead=jnp.array(False), + jump_cooldown=jnp.array(0, dtype=jnp.int32), + post_jump_cooldown=jnp.array(0, dtype=jnp.int32), + is_jumping=jnp.array(False), + is_on_road=jnp.array(True), + player_car=player_car, + round_started=jnp.array(False), + movement_steps=jnp.array(0), + steep_road_timer=jnp.array(0, dtype=jnp.int32), + jump_slope=jnp.array(0.0, dtype=jnp.float32), + collectibles=collectibles, + collectible_spawn_timer=jnp.array(0, dtype=jnp.int32), + enemy_cars=enemy_cars, + enemy_spawn_timer=jnp.array(self.consts.ENEMY_SPAWN_INTERVAL_BASE, dtype=jnp.int32), + awaiting_respawn=jnp.array(False), + awaiting_round_start=jnp.array(True), + input_released=jnp.array(False), + jump_key_released=jnp.array(True), + jump_total_duration=jnp.array(self.consts.JUMP_FRAMES, dtype=jnp.int32), + rng_key=rng_key, + ) + + @partial(jax.jit, static_argnums=(0,)) + def _get_track_corners_for_level(self, level: chex.Array) -> Tuple[chex.Array, chex.Array]: + """Return road A/B track corners for the requested level index.""" + return self.consts.FIRST_TRACK_CORNERS_X[level], self.consts.SECOND_TRACK_CORNERS_X[level] + + @partial(jax.jit, static_argnums=(0,)) + def _get_track_corners_y_for_level(self, level: chex.Array) -> chex.Array: + """Return track Y corners for the requested level index.""" + return self.consts.TRACK_CORNERS_Y[level] @partial(jax.jit, static_argnums=(0,)) def _jump_speed_allows_start(self, player_speed: chex.Array) -> chex.Array: @@ -334,17 +432,19 @@ def _adjust_enemy_spawn_timer(self, state: UpNDownState, spawn_timer: chex.Array return spawn_timer @partial(jax.jit, static_argnums=(0,)) - def _get_slope_and_intercept_from_indices(self, current_road: chex.Array, road_index_A: chex.Array, road_index_B: chex.Array) -> Tuple[chex.Array, chex.Array]: + def _get_slope_and_intercept_from_indices(self, current_road: chex.Array, road_index_A: chex.Array, road_index_B: chex.Array, level: chex.Array) -> Tuple[chex.Array, chex.Array]: """Calculate slope and intercept for the current road segment.""" + corners_a, corners_b = self._get_track_corners_for_level(level) + corners_y = self._get_track_corners_y_for_level(level) road_index = jnp.where(current_road == 0, road_index_A, road_index_B) x1 = jnp.where(current_road == 0, - self.consts.FIRST_TRACK_CORNERS_X[road_index], - self.consts.SECOND_TRACK_CORNERS_X[road_index]) + corners_a[road_index], + corners_b[road_index]) x2 = jnp.where(current_road == 0, - self.consts.FIRST_TRACK_CORNERS_X[road_index + 1], - self.consts.SECOND_TRACK_CORNERS_X[road_index + 1]) - y1 = self.consts.TRACK_CORNERS_Y[road_index] - y2 = self.consts.TRACK_CORNERS_Y[road_index + 1] + corners_a[road_index + 1], + corners_b[road_index + 1]) + y1 = corners_y[road_index] + y2 = corners_y[road_index + 1] dx = x2 - x1 dy = y2 - y1 @@ -363,10 +463,10 @@ def _is_on_line_for_position(self, position: EntityPosition, slope: chex.Array, ) @partial(jax.jit, static_argnums=(0,)) - def _get_x_on_road(self, y: chex.Array, road_segment: chex.Array, track_corners_x: chex.Array) -> chex.Array: + def _get_x_on_road(self, y: chex.Array, road_segment: chex.Array, track_corners_x: chex.Array, track_corners_y: chex.Array) -> chex.Array: """Calculate the X position on a road given a Y coordinate and road segment.""" - y1 = self.consts.TRACK_CORNERS_Y[road_segment] - y2 = self.consts.TRACK_CORNERS_Y[road_segment + 1] + y1 = track_corners_y[road_segment] + y2 = track_corners_y[road_segment + 1] x1 = track_corners_x[road_segment] x2 = track_corners_x[road_segment + 1] @@ -375,46 +475,75 @@ def _get_x_on_road(self, y: chex.Array, road_segment: chex.Array, track_corners_ return x1 + t * (x2 - x1) @partial(jax.jit, static_argnums=(0,)) - def _get_x_for_road_index(self, y: chex.Array, road_segment: chex.Array, road_index: chex.Array) -> chex.Array: + def _get_x_for_road_index(self, y: chex.Array, road_segment: chex.Array, road_index: chex.Array, level: chex.Array) -> chex.Array: """Get X position on road A (index 0) or road B (index 1) for given Y and segment.""" + corners_a, corners_b = self._get_track_corners_for_level(level) + corners_y = self._get_track_corners_y_for_level(level) track_corners = jnp.where( road_index == 0, - self.consts.FIRST_TRACK_CORNERS_X[road_segment], - self.consts.SECOND_TRACK_CORNERS_X[road_segment], + corners_a[road_segment], + corners_b[road_segment], ) track_corners_next = jnp.where( road_index == 0, - self.consts.FIRST_TRACK_CORNERS_X[road_segment + 1], - self.consts.SECOND_TRACK_CORNERS_X[road_segment + 1], + corners_a[road_segment + 1], + corners_b[road_segment + 1], ) - y1 = self.consts.TRACK_CORNERS_Y[road_segment] - y2 = self.consts.TRACK_CORNERS_Y[road_segment + 1] + y1 = corners_y[road_segment] + y2 = corners_y[road_segment + 1] t = jnp.where(y2 != y1, (y - y1) / (y2 - y1), 0.0) return track_corners + t * (track_corners_next - track_corners) @partial(jax.jit, static_argnums=(0,)) - def _get_road_segment(self, y: chex.Array) -> chex.Array: + def _get_road_segment(self, y: chex.Array, level: chex.Array) -> chex.Array: """Return the road segment index for a given y position.""" - segments = jnp.sum(self.consts.TRACK_CORNERS_Y > y, dtype=jnp.int32) - max_idx = jnp.int32(len(self.consts.TRACK_CORNERS_Y) - 1) + corners_y = self._get_track_corners_y_for_level(level) + segments = jnp.sum(corners_y > y, dtype=jnp.int32) + max_idx = jnp.int32(corners_y.shape[0] - 1) return jnp.clip(segments - 1, 0, max_idx) @partial(jax.jit, static_argnums=(0,)) - def _compute_direction_x(self, current_road: chex.Array, road_index_A: chex.Array, road_index_B: chex.Array) -> chex.Array: + def _to_track_distance_y(self, world_y: chex.Array) -> chex.Array: + """Convert wrapped world Y (negative-forward) to track-distance coordinates.""" + return jnp.mod(-world_y, self.consts.TRACK_LENGTH) + + @partial(jax.jit, static_argnums=(0,)) + def _is_level_hazard_position(self, level: chex.Array, world_y: chex.Array) -> chex.Array: + """Return whether the position is inside configured level hazard zones. + + Supports scalar or vector Y inputs. + """ + track_y = self._to_track_distance_y(world_y) + in_first_hazard = jnp.logical_and( + track_y >= self.consts.HAZARD_ZONE_1_MIN_Y, + track_y <= self.consts.HAZARD_ZONE_1_MAX_Y, + ) + in_second_hazard = jnp.logical_and( + track_y >= self.consts.HAZARD_ZONE_2_MIN_Y, + track_y <= self.consts.HAZARD_ZONE_2_MAX_Y, + ) + return jnp.logical_and( + level == self.consts.HAZARD_LEVEL_INDEX, + jnp.logical_or(in_first_hazard, in_second_hazard), + ) + + @partial(jax.jit, static_argnums=(0,)) + def _compute_direction_x(self, current_road: chex.Array, road_index_A: chex.Array, road_index_B: chex.Array, level: chex.Array) -> chex.Array: """Calculate the X direction for movement on the current road segment. Returns: Direction as int32: -1 for left, 1 for right (defaults to -1 for vertical segments) """ + corners_a, corners_b = self._get_track_corners_for_level(level) # Select the road index based on which road we're on road_index = jnp.where(current_road == 0, road_index_A, road_index_B) # Select corners for the current road x_curr = jnp.where(current_road == 0, - self.consts.FIRST_TRACK_CORNERS_X[road_index], - self.consts.SECOND_TRACK_CORNERS_X[road_index]) + corners_a[road_index], + corners_b[road_index]) x_next = jnp.where(current_road == 0, - self.consts.FIRST_TRACK_CORNERS_X[road_index + 1], - self.consts.SECOND_TRACK_CORNERS_X[road_index + 1]) + corners_a[road_index + 1], + corners_b[road_index + 1]) direction_raw = x_next - x_curr return jnp.where(direction_raw == 0, -1, jnp.sign(direction_raw)).astype(jnp.int32) @@ -450,7 +579,7 @@ def _move_on_road( return new_x, new_y @partial(jax.jit, static_argnums=(0,)) - def _is_steep_road_segment(self, current_road: chex.Array, road_index_A: chex.Array, road_index_B: chex.Array) -> chex.Array: + def _is_steep_road_segment(self, current_road: chex.Array, road_index_A: chex.Array, road_index_B: chex.Array, level: chex.Array) -> chex.Array: """Check if the current road segment is steep (no X direction change). A steep segment is one where the X coordinates of consecutive corners are the same, @@ -458,21 +587,23 @@ def _is_steep_road_segment(self, current_road: chex.Array, road_index_A: chex.Ar Returns True if the segment is steep (requires jump to pass when going up). """ + corners_a, corners_b = self._get_track_corners_for_level(level) # Get the X difference for the current road segment road_index = jnp.where(current_road == 0, road_index_A, road_index_B) x_curr = jnp.where(current_road == 0, - self.consts.FIRST_TRACK_CORNERS_X[road_index], - self.consts.SECOND_TRACK_CORNERS_X[road_index]) + corners_a[road_index], + corners_b[road_index]) x_next = jnp.where(current_road == 0, - self.consts.FIRST_TRACK_CORNERS_X[road_index + 1], - self.consts.SECOND_TRACK_CORNERS_X[road_index + 1]) + corners_a[road_index + 1], + corners_b[road_index + 1]) x_diff = jnp.abs(x_next - x_curr) # A segment is steep if there's no X change (or very small change) return x_diff < 1.0 @partial(jax.jit, static_argnums=(0,)) - def _get_steep_segment_progress(self, position_y: chex.Array, current_road: chex.Array, - road_index_A: chex.Array, road_index_B: chex.Array) -> chex.Array: + def _get_steep_segment_progress(self, position_y: chex.Array, current_road: chex.Array, + road_index_A: chex.Array, road_index_B: chex.Array, + level: chex.Array) -> chex.Array: """Calculate progress (0.0 to 1.0) through the current steep road segment. 0.0 = at the bottom (start) of the steep segment @@ -482,9 +613,10 @@ def _get_steep_segment_progress(self, position_y: chex.Array, current_road: chex but Y decreases as we go forward on the track). """ road_index = jnp.where(current_road == 0, road_index_A, road_index_B) + corners_y = self._get_track_corners_y_for_level(level) # Y coordinates of segment boundaries - y_start = self.consts.TRACK_CORNERS_Y[road_index] # Start of segment (lower Y = further ahead) - y_end = self.consts.TRACK_CORNERS_Y[road_index + 1] # End of segment (higher Y in absolute terms) + y_start = corners_y[road_index] # Start of segment (lower Y = further ahead) + y_end = corners_y[road_index + 1] # End of segment (higher Y in absolute terms) # Calculate progress: how far through the segment are we? # Since Y decreases as we go forward, we need to invert @@ -498,6 +630,7 @@ def _get_steep_segment_progress(self, position_y: chex.Array, current_road: chex @partial(jax.jit, static_argnums=(0,)) def _check_landing_position( self, + level: chex.Array, road_index_A: chex.Array, road_index_B: chex.Array, new_position_x: chex.Array, @@ -508,21 +641,23 @@ def _check_landing_position( Returns: Tuple of (landing_in_water, between_roads, road_A_x, road_B_x) """ + corners_a, corners_b = self._get_track_corners_for_level(level) + corners_y = self._get_track_corners_y_for_level(level) # Calculate X position on road A at the given Y - y_ratio_A = (new_position_y - self.consts.TRACK_CORNERS_Y[road_index_A]) / ( - self.consts.TRACK_CORNERS_Y[road_index_A + 1] - self.consts.TRACK_CORNERS_Y[road_index_A] + y_ratio_A = (new_position_y - corners_y[road_index_A]) / ( + corners_y[road_index_A + 1] - corners_y[road_index_A] ) road_A_x = y_ratio_A * ( - self.consts.FIRST_TRACK_CORNERS_X[road_index_A + 1] - self.consts.FIRST_TRACK_CORNERS_X[road_index_A] - ) + self.consts.FIRST_TRACK_CORNERS_X[road_index_A] + corners_a[road_index_A + 1] - corners_a[road_index_A] + ) + corners_a[road_index_A] # Calculate X position on road B at the given Y - y_ratio_B = (new_position_y - self.consts.TRACK_CORNERS_Y[road_index_B]) / ( - self.consts.TRACK_CORNERS_Y[road_index_B + 1] - self.consts.TRACK_CORNERS_Y[road_index_B] + y_ratio_B = (new_position_y - corners_y[road_index_B]) / ( + corners_y[road_index_B + 1] - corners_y[road_index_B] ) road_B_x = y_ratio_B * ( - self.consts.SECOND_TRACK_CORNERS_X[road_index_B + 1] - self.consts.SECOND_TRACK_CORNERS_X[road_index_B] - ) + self.consts.SECOND_TRACK_CORNERS_X[road_index_B] + corners_b[road_index_B + 1] - corners_b[road_index_B] + ) + corners_b[road_index_B] distance_to_road_A = jnp.abs(new_position_x - road_A_x) distance_to_road_B = jnp.abs(new_position_x - road_B_x) @@ -539,6 +674,7 @@ def _check_landing_position( @partial(jax.jit, static_argnums=(0,)) def _advance_player_car( self, + level: chex.Array, position_x: chex.Array, position_y: chex.Array, road_index_A: chex.Array, @@ -568,10 +704,13 @@ def _advance_player_car( move_y, move_x, step_size, speed_sign = self._compute_movement_timing(speed, step_counter) # Get slope and intercept for current road - slope, b = self._get_slope_and_intercept_from_indices(current_road, road_index_A, road_index_B) + slope, b = self._get_slope_and_intercept_from_indices(current_road, road_index_A, road_index_B, level) # Determine X direction based on current road segment (for normal movement) - car_direction_x = self._compute_direction_x(current_road, road_index_A, road_index_B) + car_direction_x = self._compute_direction_x(current_road, road_index_A, road_index_B, level) + + corners_a, corners_b = self._get_track_corners_for_level(level) + corners_y = self._get_track_corners_y_for_level(level) position = EntityPosition(x=position_x, y=position_y, width=width, height=height) @@ -603,9 +742,9 @@ def _advance_player_car( # === AIR STEERING / MAGNETISM === # Gradually steer towards the nearest road while in the air to prevent "teleporting" on landing - segment_curr = self._get_road_segment(new_player_y) - road_A_x_curr = self._get_x_on_road(new_player_y, segment_curr, self.consts.FIRST_TRACK_CORNERS_X) - road_B_x_curr = self._get_x_on_road(new_player_y, segment_curr, self.consts.SECOND_TRACK_CORNERS_X) + segment_curr = self._get_road_segment(new_player_y, level) + road_A_x_curr = self._get_x_on_road(new_player_y, segment_curr, corners_a, corners_y) + road_B_x_curr = self._get_x_on_road(new_player_y, segment_curr, corners_b, corners_y) dist_A = jnp.abs(raw_jump_x - road_A_x_curr) dist_B = jnp.abs(raw_jump_x - road_B_x_curr) @@ -637,11 +776,11 @@ def _advance_player_car( # === LANDING LOGIC === # Get the current road segment based on new Y position - segment = self._get_road_segment(new_player_y) + segment = self._get_road_segment(new_player_y, level) # Calculate X positions of both roads at the new Y position - road_A_x = self._get_x_on_road(new_player_y, segment, self.consts.FIRST_TRACK_CORNERS_X) - road_B_x = self._get_x_on_road(new_player_y, segment, self.consts.SECOND_TRACK_CORNERS_X) + road_A_x = self._get_x_on_road(new_player_y, segment, corners_a, corners_y) + road_B_x = self._get_x_on_road(new_player_y, segment, corners_b, corners_y) # Calculate distances to each road dist_to_road_A = jnp.abs(new_player_x - road_A_x) @@ -740,6 +879,7 @@ def _advance_player_car( @partial(jax.jit, static_argnums=(0,)) def _advance_car_core( self, + level: chex.Array, position_x: chex.Array, position_y: chex.Array, road_index_A: chex.Array, @@ -754,8 +894,8 @@ def _advance_car_core( """Simplified car advancement for enemy cars (no jumping/landing logic).""" # Calculate movement timing using helper move_y, move_x, step_size, speed_sign = self._compute_movement_timing(speed, step_counter) - slope, b = self._get_slope_and_intercept_from_indices(current_road, road_index_A, road_index_B) - car_direction_x = self._compute_direction_x(current_road, road_index_A, road_index_B) + slope, b = self._get_slope_and_intercept_from_indices(current_road, road_index_A, road_index_B, level) + car_direction_x = self._compute_direction_x(current_road, road_index_A, road_index_B, level) position = EntityPosition(x=position_x, y=position_y, width=width, height=height) @@ -767,7 +907,7 @@ def _advance_car_core( wrapped_y = -((new_y * -1) % self.consts.TRACK_LENGTH) # Update road segment indices based on new position - segment_from_y = self._get_road_segment(new_y) + segment_from_y = self._get_road_segment(new_y, level) # Update road indices to track the current segment (use jnp.where for branchless execution) next_road_index_A = jnp.where(current_road == 0, segment_from_y, road_index_A) @@ -791,10 +931,12 @@ def _advance_car_core( @partial(jax.jit, static_argnums=(0,)) def _flag_step(self, state: UpNDownState, new_player_y: chex.Array, player_x: chex.Array, current_road: chex.Array) -> Tuple[Flag, chex.Array, chex.Array]: """Update flag collection state and score (vectorized).""" + corners_a, corners_b = self._get_track_corners_for_level(state.level) + corners_y = self._get_track_corners_y_for_level(state.level) # Calculate flag X positions on both roads # _get_x_on_road supports array inputs via advanced indexing - x_road_0 = self._get_x_on_road(state.flags.y, state.flags.road_segment, self.consts.FIRST_TRACK_CORNERS_X) - x_road_1 = self._get_x_on_road(state.flags.y, state.flags.road_segment, self.consts.SECOND_TRACK_CORNERS_X) + x_road_0 = self._get_x_on_road(state.flags.y, state.flags.road_segment, corners_a, corners_y) + x_road_1 = self._get_x_on_road(state.flags.y, state.flags.road_segment, corners_b, corners_y) flag_x = jnp.where(state.flags.road == 0, x_road_0, x_road_1) @@ -873,11 +1015,13 @@ def _collectible_step(self, state: UpNDownState, new_player_y: chex.Array, playe type_id_spawn = jnp.clip(type_id_spawn, 0, 3).astype(jnp.int32) # Calculate X position on road (use jnp.where for branchless) - segment_spawn = self._get_road_segment(y_spawn) + segment_spawn = self._get_road_segment(y_spawn, state.level) + corners_a, corners_b = self._get_track_corners_for_level(state.level) + corners_y = self._get_track_corners_y_for_level(state.level) x_spawn = jnp.where( road_spawn == 0, - self._get_x_on_road(y_spawn, segment_spawn, self.consts.FIRST_TRACK_CORNERS_X), - self._get_x_on_road(y_spawn, segment_spawn, self.consts.SECOND_TRACK_CORNERS_X), + self._get_x_on_road(y_spawn, segment_spawn, corners_a, corners_y), + self._get_x_on_road(y_spawn, segment_spawn, corners_b, corners_y), ) # Create mask for which collectibles to update @@ -993,6 +1137,7 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: state.player_car.current_road, state.player_car.road_index_A, state.player_car.road_index_B, + state.level, ) # Calculate progress through steep segment (0.0 = bottom, 1.0 = top) @@ -1001,10 +1146,15 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: state.player_car.current_road, state.player_car.road_index_A, state.player_car.road_index_B, + state.level, ) # Determine if player is on steep road going up (not jumping) - on_steep_not_jumping = jnp.logical_and(is_on_steep_road, jnp.logical_not(state.is_jumping)) + use_steep_mechanics = state.level == 0 + on_steep_not_jumping = jnp.logical_and( + use_steep_mechanics, + jnp.logical_and(is_on_steep_road, jnp.logical_not(state.is_jumping)), + ) # Start with current speed player_speed = state.player_car.speed @@ -1121,26 +1271,28 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: state.player_car.road_index_A, state.player_car.road_index_B, ) + corners_a, corners_b = self._get_track_corners_for_level(state.level) # Get corner coordinates for the current segment # Segment goes from corner[road_index] to corner[road_index+1] # Use jnp.where for branchless execution start_x = jnp.where( state.player_car.current_road == 0, - self.consts.FIRST_TRACK_CORNERS_X[road_index], - self.consts.SECOND_TRACK_CORNERS_X[road_index], + corners_a[road_index], + corners_b[road_index], ) end_x = jnp.where( state.player_car.current_road == 0, - self.consts.FIRST_TRACK_CORNERS_X[road_index + 1], - self.consts.SECOND_TRACK_CORNERS_X[road_index + 1], + corners_a[road_index + 1], + corners_b[road_index + 1], ) - start_y = self.consts.TRACK_CORNERS_Y[road_index] + corners_y = self._get_track_corners_y_for_level(state.level) + start_y = corners_y[road_index] end_y = jnp.where( - jnp.equal(self.consts.FIRST_TRACK_CORNERS_X[road_index + 1], self.consts.FIRST_TRACK_CORNERS_X[road_index + 2]), - self.consts.TRACK_CORNERS_Y[road_index + 2], - self.consts.TRACK_CORNERS_Y[road_index + 1], + jnp.equal(corners_a[road_index + 1], corners_a[road_index + 2]), + corners_y[road_index + 2], + corners_y[road_index + 1], ) # Calculate slope: how much X changes per unit Y change @@ -1188,6 +1340,7 @@ def _player_step(self, state: UpNDownState, action: chex.Array) -> UpNDownState: jump_progress = jnp.clip(jump_progress, 0.0, 1.0) updated_player_car = self._advance_player_car( + level=state.level, position_x=state.player_car.position.x, position_y=state.player_car.position.y, road_index_A=state.player_car.road_index_A, @@ -1269,6 +1422,7 @@ def _flag_step_main(self, state: UpNDownState) -> UpNDownState: @partial(jax.jit, static_argnums=(0,)) def _level_progression_step(self, state: UpNDownState) -> UpNDownState: """Handle level completion: award bonus and reset flags.""" + # Temporary test shortcut: progress after collecting any single flag. all_flags_collected = jnp.all(state.flags_collected_mask) bonus = jnp.where(all_flags_collected, self.consts.ALL_FLAGS_BONUS, 0) @@ -1334,9 +1488,11 @@ def _initialize_collectibles(self) -> Collectible: ) @partial(jax.jit, static_argnums=(0,)) - def _initialize_enemies(self, key: chex.Array, player_start_y: chex.Array) -> EnemyCars: + def _initialize_enemies(self, key: chex.Array, player_start_y: chex.Array, level: chex.Array) -> EnemyCars: """Seed the initial set of visible enemies around the player.""" key_init, key_type, key_road, key_speed, key_sign = jax.random.split(key, 5) + corners_a, corners_b = self._get_track_corners_for_level(level) + corners_y = self._get_track_corners_y_for_level(level) offsets = self.consts.INITIAL_ENEMY_BASE_OFFSET + self.consts.INITIAL_ENEMY_GAP * jnp.arange(self.consts.INITIAL_ENEMY_COUNT) spawn_signs = jax.random.choice(key_sign, jnp.array([-1.0, 1.0]), shape=(self.consts.INITIAL_ENEMY_COUNT,)) @@ -1344,12 +1500,12 @@ def _initialize_enemies(self, key: chex.Array, player_start_y: chex.Array) -> En init_y = -(((raw_spawn_y) * -1) % self.consts.TRACK_LENGTH) init_road = jax.random.randint(key_road, shape=(self.consts.INITIAL_ENEMY_COUNT,), minval=0, maxval=2) - init_segments = jax.vmap(self._get_road_segment)(init_y) + init_segments = jax.vmap(lambda y: self._get_road_segment(y, level))(init_y) init_x = jax.vmap(lambda y, seg, road: jax.lax.cond( road == 0, - lambda _: self._get_x_on_road(y, seg, self.consts.FIRST_TRACK_CORNERS_X), - lambda _: self._get_x_on_road(y, seg, self.consts.SECOND_TRACK_CORNERS_X), + lambda _: self._get_x_on_road(y, seg, corners_a, corners_y), + lambda _: self._get_x_on_road(y, seg, corners_b, corners_y), operand=None, ))(init_y, init_segments, init_road) @@ -1361,8 +1517,8 @@ def _initialize_enemies(self, key: chex.Array, player_start_y: chex.Array) -> En def init_direction(seg, road): raw = jax.lax.cond( road == 0, - lambda _: self.consts.FIRST_TRACK_CORNERS_X[seg+1] - self.consts.FIRST_TRACK_CORNERS_X[seg], - lambda _: self.consts.SECOND_TRACK_CORNERS_X[seg+1] - self.consts.SECOND_TRACK_CORNERS_X[seg], + lambda _: corners_a[seg + 1] - corners_a[seg], + lambda _: corners_b[seg + 1] - corners_b[seg], operand=None, ) return jax.lax.cond(raw > 0, lambda _: 1, lambda _: -1, operand=None) @@ -1453,11 +1609,13 @@ def _enemy_step_main(self, state: UpNDownState) -> UpNDownState: spawn_y = -(((raw_spawn_y) * -1) % self.consts.TRACK_LENGTH) spawn_road = self._sample_enemy_spawn_road(key_spawn_direction) - segment_spawn = self._get_road_segment(spawn_y) + segment_spawn = self._get_road_segment(spawn_y, state.level) + corners_a, corners_b = self._get_track_corners_for_level(state.level) + corners_y = self._get_track_corners_y_for_level(state.level) spawn_x = jnp.where( spawn_road == 0, - self._get_x_on_road(spawn_y, segment_spawn, self.consts.FIRST_TRACK_CORNERS_X), - self._get_x_on_road(spawn_y, segment_spawn, self.consts.SECOND_TRACK_CORNERS_X), + self._get_x_on_road(spawn_y, segment_spawn, corners_a, corners_y), + self._get_x_on_road(spawn_y, segment_spawn, corners_b, corners_y), ) spawn_speed_mag = jax.random.randint(key_spawn_speed, shape=(), minval=self.consts.ENEMY_SPEED_MIN, maxval=self.consts.ENEMY_SPEED_MAX + 1) @@ -1467,8 +1625,8 @@ def _enemy_step_main(self, state: UpNDownState) -> UpNDownState: direction_raw = jnp.where( spawn_road == 0, - self.consts.FIRST_TRACK_CORNERS_X[segment_spawn+1] - self.consts.FIRST_TRACK_CORNERS_X[segment_spawn], - self.consts.SECOND_TRACK_CORNERS_X[segment_spawn+1] - self.consts.SECOND_TRACK_CORNERS_X[segment_spawn], + self.consts.FIRST_TRACK_CORNERS_X[state.level, segment_spawn + 1] - self.consts.FIRST_TRACK_CORNERS_X[state.level, segment_spawn], + self.consts.SECOND_TRACK_CORNERS_X[state.level, segment_spawn + 1] - self.consts.SECOND_TRACK_CORNERS_X[state.level, segment_spawn], ) spawn_direction_x = jnp.where(direction_raw > 0, 1, -1) @@ -1490,6 +1648,7 @@ def _enemy_step_main(self, state: UpNDownState) -> UpNDownState: enemy_speed = jnp.where(jnp.logical_and(enemy_active, flip_mask), -enemy_speed, enemy_speed) move_fn = lambda px, py, ra, rb, cr, sp, tp: self._advance_car_core( + level=state.level, position_x=px, position_y=py, road_index_A=ra, @@ -1525,9 +1684,12 @@ def _enemy_step_main(self, state: UpNDownState) -> UpNDownState: wrapped_dist = jnp.minimum(jnp.abs(delta_y), self.consts.TRACK_LENGTH - jnp.abs(delta_y)) far_mask = wrapped_dist > self.consts.ENEMY_DESPAWN_DISTANCE age_mask = enemy_age > self.consts.ENEMY_MAX_AGE + hazard_mask = self._is_level_hazard_position(state.level, moved_position_y) despawn_mask = jnp.logical_and(enemy_active, jnp.logical_or(far_mask, age_mask)) - final_active = jnp.logical_and(enemy_active, jnp.logical_not(despawn_mask)) - enemy_age = jnp.where(despawn_mask, jnp.zeros_like(enemy_age), enemy_age) + hazard_despawn_mask = jnp.logical_and(enemy_active, hazard_mask) + total_despawn_mask = jnp.logical_or(despawn_mask, hazard_despawn_mask) + final_active = jnp.logical_and(enemy_active, jnp.logical_not(total_despawn_mask)) + enemy_age = jnp.where(total_despawn_mask, jnp.zeros_like(enemy_age), enemy_age) next_enemy_cars = EnemyCars( position=EntityPosition( @@ -1559,18 +1721,10 @@ def _respawn_after_collision(self, state: UpNDownState, new_lives: chex.Array) - """Respawn the player on a random road while preserving score and flags.""" rng_key, road_key, enemy_key = jax.random.split(state.rng_key, 3) - player_start_y = jnp.array(0.0) - start_segment = jnp.array(0, dtype=jnp.int32) respawn_road = jax.random.randint(road_key, shape=(), minval=0, maxval=2) + start_segment, player_start_y, start_x = self._get_spawn_position_for_level(state.level, respawn_road) - start_x = jax.lax.cond( - respawn_road == 0, - lambda _: self._get_x_on_road(player_start_y, start_segment, self.consts.FIRST_TRACK_CORNERS_X), - lambda _: self._get_x_on_road(player_start_y, start_segment, self.consts.SECOND_TRACK_CORNERS_X), - operand=None, - ) - - enemy_cars = self._initialize_enemies(enemy_key, player_start_y) + enemy_cars = self._initialize_enemies(enemy_key, player_start_y, state.level) collectibles = self._initialize_collectibles() player_car = Car( @@ -1601,6 +1755,7 @@ def _respawn_after_collision(self, state: UpNDownState, new_lives: chex.Array) - player_car=player_car, step_counter=state.step_counter, round_started=jnp.array(False), + level=state.level, movement_steps=jnp.array(0), steep_road_timer=jnp.array(0, dtype=jnp.int32), jump_slope=jnp.array(0.0, dtype=jnp.float32), @@ -1629,6 +1784,7 @@ def _enemy_collision_step_main(self, state: UpNDownState) -> UpNDownState: without clearing score or collected flags. - Landing collisions use a larger distance and are road-independent (for crossings). """ + player_x = state.player_car.position.x player_y = state.player_car.position.y @@ -1666,6 +1822,11 @@ def _enemy_collision_step_main(self, state: UpNDownState) -> UpNDownState: jnp.logical_and(jnp.logical_not(state.is_jumping), jnp.logical_not(is_invincible)) ) + level_three_grounded_hazard = jnp.logical_and( + jnp.logical_not(state.is_jumping), + self._is_level_hazard_position(state.level, player_y), + ) + def handle_late_jump(): hits = collision_mask.astype(jnp.int32) bonus = jnp.sum(hits) * self.consts.LATE_JUMP_ENEMY_SCORE @@ -1697,8 +1858,8 @@ def handle_ground_collision(): player_car=dead_car, ) - # Ground collision causes death (landing is now protected by invincibility) - any_fatal_collision = grounded_collision + # Ground collision or level-specific grounded hazard causes death. + any_fatal_collision = jnp.logical_or(grounded_collision, level_three_grounded_hazard) return jax.lax.cond( late_jump_collision, @@ -1726,6 +1887,9 @@ def _passive_score_step_main(self, state: UpNDownState) -> UpNDownState: @partial(jax.jit, static_argnums=(0,)) def _reset_jit(self, key: chex.PRNGKey) -> Tuple[UpNDownObservation, UpNDownState]: rng_key, flag_key, enemy_key = jax.random.split(key, 3) + initial_level = jnp.int32(0) + start_road = jnp.int32(jax.random.randint(rng_key, shape=(), minval=0, maxval=2)) + start_segment, player_start_y, player_start_x = self._get_spawn_position_for_level(initial_level, start_road) # Evenly spread flags along the track with small jitter base_y = jnp.linspace(-900.0, -100.0, self.consts.NUM_FLAGS) @@ -1736,7 +1900,7 @@ def _reset_jit(self, key: chex.PRNGKey) -> Tuple[UpNDownObservation, UpNDownStat flag_roads = jnp.arange(self.consts.NUM_FLAGS) % 2 # Calculate which road segment each flag is on based on Y position - flag_segments = jax.vmap(self._get_road_segment)(flag_y_offsets) + flag_segments = jax.vmap(lambda y: self._get_road_segment(y, initial_level))(flag_y_offsets) # Each flag color index corresponds to its position (0-7) flag_color_indices = jnp.arange(self.consts.NUM_FLAGS) @@ -1753,8 +1917,7 @@ def _reset_jit(self, key: chex.PRNGKey) -> Tuple[UpNDownObservation, UpNDownStat collectibles = self._initialize_collectibles() # Seed initial visible enemies spaced around the player - player_start_y = jnp.array(0.0) - enemy_cars = self._initialize_enemies(enemy_key, player_start_y) + enemy_cars = self._initialize_enemies(enemy_key, player_start_y, initial_level) state = UpNDownState( score=0, @@ -1768,21 +1931,22 @@ def _reset_jit(self, key: chex.PRNGKey) -> Tuple[UpNDownObservation, UpNDownStat is_on_road=True, player_car=Car( position=EntityPosition( - x=jnp.asarray(30, dtype=jnp.float32), - y=jnp.asarray(0, dtype=jnp.float32), + x=jnp.asarray(player_start_x, dtype=jnp.float32), + y=jnp.asarray(player_start_y, dtype=jnp.float32), width=self.consts.PLAYER_SIZE[0], height=self.consts.PLAYER_SIZE[1], ), speed=0, direction_x=0, - current_road=0, - road_index_A=0, - road_index_B=0, + current_road=start_road, + road_index_A=start_segment, + road_index_B=start_segment, type=0, ), step_counter=jnp.array(0), rng_key=rng_key, round_started=jnp.array(False), + level=initial_level, movement_steps=jnp.array(0), steep_road_timer=jnp.array(0, dtype=jnp.int32), jump_slope=jnp.array(0.0, dtype=jnp.float32), @@ -1892,6 +2056,7 @@ def _get_observation(self, state: UpNDownState) -> UpNDownObservation: state.player_car.current_road, state.player_car.road_index_A, state.player_car.road_index_B, + state.level, ) return UpNDownObservation( @@ -2121,7 +2286,7 @@ def __init__(self, consts: UpNDownConstants = None, config: render_utils.Rendere blackout_square = self._createBackgroundSprite(self.consts.FLAG_BLACKOUT_SIZE) # Build asset config locally (matches other games' pattern) - asset_config, road_files = self._get_asset_config(background, top_block, bottom_block, temp_pointer, blackout_square) + asset_config, level_background_files = self._get_asset_config(background, top_block, bottom_block, temp_pointer, blackout_square) sprite_path = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/" ( @@ -2131,13 +2296,91 @@ def __init__(self, consts: UpNDownConstants = None, config: render_utils.Rendere self.COLOR_TO_ID, self.FLIP_OFFSETS ) = self.jr.load_and_setup_assets(asset_config, sprite_path) - self.road_sizes, self.complete_road_size = self._get_road_sprite_sizes(road_files) self.view_height = self.config.game_dimensions[0] - # Precompute offsets so repeated road tiles can wrap seamlessly without gaps. - road_cycle = max(1, self.complete_road_size) - repeats = max(1, int(-(-self.view_height // road_cycle)) + 2) # Ceiling division trick - self._road_tile_offsets = jnp.arange(-repeats, repeats + 1, dtype=jnp.int32) * jnp.int32(self.complete_road_size) - self._num_road_tiles = int(self._road_tile_offsets.shape[0]) + + level_masks_raw = [ + self.SHAPE_MASKS["background_level_1"], + self.SHAPE_MASKS["background_level_2"], + self.SHAPE_MASKS["background_level_3"], + ] + level_sizes = [] + level_cycles = [] + for files in level_background_files: + sizes, cycle = self._get_group_sprite_sizes(files) + level_sizes.append(sizes) + level_cycles.append(cycle) + + self._max_background_segments = max(mask.shape[0] for mask in level_masks_raw) + self._max_background_h = max(mask.shape[1] for mask in level_masks_raw) + self._max_background_w = max(mask.shape[2] for mask in level_masks_raw) + + padded_masks = [] + padded_sizes = [] + padded_content_heights = [] + padded_top_trims = [] + padded_anchor_offsets = [] + level_counts = [] + level_content_cycles = [] + for idx in range(self.consts.LEVEL_COUNT): + mask = level_masks_raw[idx] + count = int(mask.shape[0]) + pad_n = self._max_background_segments - count + pad_h = self._max_background_h - mask.shape[1] + pad_w = self._max_background_w - mask.shape[2] + padded_mask = jnp.pad(mask, ((0, pad_n), (0, pad_h), (0, pad_w)), constant_values=self.jr.TRANSPARENT_ID) + padded_masks.append(padded_mask) + + sizes = level_sizes[idx] + padded_sizes.append(sizes + [0] * (self._max_background_segments - len(sizes))) + + # Use per-sprite opaque bounds to stack sections robustly when dimensions differ across levels. + mask_np = np.asarray(mask) + content_heights: list[int] = [] + top_trims: list[int] = [] + for seg_idx in range(count): + seg = mask_np[seg_idx] + opaque_rows = np.any(seg != self.jr.TRANSPARENT_ID, axis=1) + if np.any(opaque_rows): + first_row = int(np.argmax(opaque_rows)) + last_row = int(len(opaque_rows) - 1 - np.argmax(opaque_rows[::-1])) + top_trims.append(first_row) + content_heights.append(last_row - first_row + 1) + else: + top_trims.append(0) + content_heights.append(0) + + padded_content_heights.append(content_heights + [0] * (self._max_background_segments - len(content_heights))) + padded_top_trims.append(top_trims + [0] * (self._max_background_segments - len(top_trims))) + + # Anchor sections so each next section sits directly above the previous one. + # This matches the map moving downward on screen. + anchor_offsets: list[int] = [0] + for seg_idx in range(1, count): + anchor_offsets.append(anchor_offsets[-1] - content_heights[seg_idx]) + padded_anchor_offsets.append(anchor_offsets + [0] * (self._max_background_segments - len(anchor_offsets))) + + level_content_cycles.append(int(sum(content_heights))) + level_counts.append(count) + + self.level_background_masks = jnp.stack(padded_masks, axis=0) + self.level_background_sizes = jnp.array(padded_sizes, dtype=jnp.int32) + self.level_background_content_heights = jnp.array(padded_content_heights, dtype=jnp.int32) + self.level_background_top_trims = jnp.array(padded_top_trims, dtype=jnp.int32) + self.level_background_anchor_offsets = jnp.array(padded_anchor_offsets, dtype=jnp.int32) + self.level_background_counts = jnp.array(level_counts, dtype=jnp.int32) + self.level_background_cycle_sizes = jnp.array(level_content_cycles, dtype=jnp.int32) + max_background_cycle = max(1, int(max(level_content_cycles))) + bg_repeats = max(1, int(-(-self.view_height // max_background_cycle)) + 2) + self._background_tile_indices = jnp.arange(-bg_repeats, bg_repeats + 1, dtype=jnp.int32) + self._num_background_tiles = int(self._background_tile_indices.shape[0]) + self._background_segment_indices = jnp.tile( + jnp.arange(self._max_background_segments, dtype=jnp.int32), + self._num_background_tiles, + ) + self._background_draw_order_indices = jnp.arange( + self._num_background_tiles * self._max_background_segments, + dtype=jnp.int32, + ) self.enemy_sprite_names = { self.consts.ENEMY_TYPE_CAMERO: "camero_left", @@ -2202,27 +2445,52 @@ def _createBackgroundSprite(self, dimensions: Tuple[int, int]) -> jnp.ndarray: sprite = jnp.tile(jnp.array(color, dtype=jnp.uint8), (*shape[:2], 1)) return sprite - def _get_road_sprite_sizes(self, road_files: list[str]) -> list: - """Returns the sizes of the road sprites limited to the configured files.""" - road_dir = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/roads" + def _extract_numeric_suffix(self, filename: str, prefix: str) -> int: + stem = os.path.splitext(filename)[0] + match = re.search(rf"^{re.escape(prefix)}(\d+)$", stem) + return int(match.group(1)) if match else 10**9 + + def _sorted_background_files(self, background_dir: str, prefix: str) -> list[str]: + files = [file for file in os.listdir(background_dir) if file.endswith(".npy") and file.startswith(prefix)] + return sorted(files, key=lambda file: self._extract_numeric_suffix(file, prefix)) + + def _get_group_sprite_sizes(self, relative_files: list[str]) -> Tuple[list[int], int]: + """Returns sprite heights and total height for a configured file group.""" + sprite_root = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down" sizes = [] - for file in road_files: - sprite_name = os.path.basename(file) - sprite = jnp.load(f"{road_dir}/{sprite_name}") + for relative_file in relative_files: + sprite = jnp.load(f"{sprite_root}/{relative_file}") sizes.append(sprite.shape[0]) complete_size = int(sum(sizes)) return sizes, complete_size - def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.ndarray, bottomBlockSprite: jnp.ndarray, tempPointer: jnp.ndarray, blackoutSquare: jnp.ndarray) -> tuple[list, list[str]]: + def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.ndarray, bottomBlockSprite: jnp.ndarray, tempPointer: jnp.ndarray, blackoutSquare: jnp.ndarray) -> tuple[list, list[list[str]]]: """Return asset manifest and ordered road files (renderer-local like other games).""" road_dir = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/roads" + background_dir = f"{os.path.dirname(os.path.abspath(__file__))}/sprites/up_n_down/background" road_files = sorted( file for file in os.listdir(road_dir) if file.endswith(".npy") ) roads = [f"roads/{file}" for file in road_files] + + # Keep level 1 on the original road sections to preserve legacy alignment. + level_1_background = roads + level_2_background = [f"background/{file}" for file in self._sorted_background_files(background_dir, "backround_lvl2_")] + level_3_background = [f"background/{file}" for file in self._sorted_background_files(background_dir, "backround_lvl3_")] + + if not level_1_background: + raise ValueError("Missing level 1 background sprite files") + if not level_2_background: + level_2_background = level_1_background + if not level_3_background: + level_3_background = level_1_background + return [ {'name': 'background', 'type': 'background', 'data': backgroundSprite}, + {'name': 'background_level_1', 'type': 'group', 'files': level_1_background}, + {'name': 'background_level_2', 'type': 'group', 'files': level_2_background}, + {'name': 'background_level_3', 'type': 'group', 'files': level_3_background}, {'name': 'road', 'type': 'group', 'files': roads}, {'name': 'player', 'type': 'single', 'file': 'player_car.npy'}, {'name': 'camero_left', 'type': 'single', 'file': 'enemy_cars/camero_left.npy'}, @@ -2242,12 +2510,13 @@ def _get_asset_config(self, backgroundSprite: jnp.ndarray, topBlockSprite: jnp.n {'name': 'ice_cream', 'type': 'single', 'file': 'ice_cream_cone.npy'}, {'name': 'tempPointer', 'type': 'procedural', 'data': tempPointer}, {'name': 'blackout_square', 'type': 'procedural', 'data': blackoutSquare}, - ], roads + ], [level_1_background, level_2_background, level_3_background] - def _get_x_on_road(self, y: chex.Array, road_segment: chex.Array, track_corners_x: chex.Array) -> chex.Array: + def _get_x_on_road(self, y: chex.Array, road_segment: chex.Array, track_corners_x: chex.Array, level: chex.Array) -> chex.Array: """Calculate the X position on a road given a Y coordinate and road segment.""" - y1 = self.consts.TRACK_CORNERS_Y[road_segment] - y2 = self.consts.TRACK_CORNERS_Y[road_segment + 1] + corners_y = self.consts.TRACK_CORNERS_Y[level] + y1 = corners_y[road_segment] + y2 = corners_y[road_segment + 1] x1 = track_corners_x[road_segment] x2 = track_corners_x[road_segment + 1] t = jnp.where(y2 != y1, (y - y1) / (y2 - y1), 0.0) @@ -2278,36 +2547,63 @@ def _jump_arc_offset(self, jump_cooldown: chex.Array, total_duration: chex.Array @partial(jax.jit, static_argnums=(0,)) def render(self, state): raster = self.jr.create_object_raster(self.BACKGROUND) - road_diff = (-state.player_car.position.y) % self.complete_road_size - - # Vectorized road rendering: compute all Y offsets, stamp via vmap, fold overlays. - road_masks = self.SHAPE_MASKS["road"] # shape: (N, H, W) - num_segments = road_masks.shape[0] - - sizes = jnp.array(self.road_sizes, dtype=jnp.int32) - # Offsets: [0, cumsum(sizes[1:])] - offsets = jnp.concatenate([ + level_index = jnp.asarray(state.level, dtype=jnp.int32) + + background_masks = self.level_background_masks[level_index] + background_sizes = self.level_background_sizes[level_index] + background_content_heights = self.level_background_content_heights[level_index] + background_top_trims = self.level_background_top_trims[level_index] + background_anchor_offsets = self.level_background_anchor_offsets[level_index] + background_count = self.level_background_counts[level_index] + + # Keep legacy level-1 alignment behavior; level 2/3 use robust multi-level mapping. + level1_cycle = jnp.maximum(jnp.sum(background_sizes), 1) + level1_diff = ((-state.player_car.position.y) % level1_cycle).astype(jnp.int32) + level1_offsets = jnp.concatenate([ jnp.array([0], dtype=jnp.int32), - jnp.cumsum(sizes[1:], axis=0) + jnp.cumsum(background_sizes[1:], axis=0), ], axis=0) + background_cycle = jnp.maximum(self.level_background_cycle_sizes[level_index], 1) + track_progress = (-state.player_car.position.y) % jnp.asarray(self.consts.TRACK_LENGTH, dtype=jnp.float32) + background_diff = jnp.floor( + track_progress * background_cycle.astype(jnp.float32) / jnp.asarray(self.consts.TRACK_LENGTH, dtype=jnp.float32) + ).astype(jnp.int32) + base_y = jnp.asarray(self.consts.INITIAL_ROAD_POS_Y, dtype=jnp.int32) - y_positions = base_y + (road_diff.astype(jnp.int32)) - offsets + level1_y_positions = base_y + level1_diff - level1_offsets + # Align based on opaque bounds and precomputed anchor ordering. + levelN_y_positions = base_y + background_diff + background_anchor_offsets - background_top_trims + is_level1 = level_index == jnp.int32(0) + background_y_positions = jnp.where(is_level1, level1_y_positions, levelN_y_positions) + + draw_sizes = jnp.where(is_level1, background_sizes, background_content_heights) + draw_cycle = jnp.where(is_level1, level1_cycle.astype(jnp.int32), background_cycle.astype(jnp.int32)) + + background_tile_offsets = self._background_tile_indices * draw_cycle + background_tile_count = self._num_background_tiles - tile_offsets = self._road_tile_offsets - tile_count = self._num_road_tiles - tiled_y = (y_positions[None, :] + tile_offsets[:, None]).reshape(-1) - tiled_masks = jnp.tile(road_masks, (tile_count, 1, 1)) - tiled_sizes = jnp.tile(sizes, tile_count) + tiled_background_y = (background_y_positions[None, :] + background_tile_offsets[:, None]).reshape(-1) + tiled_background_masks = jnp.tile(background_masks, (background_tile_count, 1, 1)) + tiled_background_sizes = jnp.tile(draw_sizes, background_tile_count) + tiled_background_indices = self._background_segment_indices - visible = jnp.logical_and( - tiled_y < self.view_height, - (tiled_y + tiled_sizes) > 0 + total_background_segments = background_tile_count * self._max_background_segments + draw_keys = tiled_background_y.astype(jnp.int32) * jnp.int32(total_background_segments + 1) + self._background_draw_order_indices + draw_order = jnp.argsort(draw_keys) + sorted_background_y = tiled_background_y[draw_order] + sorted_background_masks = tiled_background_masks[draw_order] + sorted_background_sizes = tiled_background_sizes[draw_order] + sorted_background_indices = tiled_background_indices[draw_order] + + background_visible = jnp.logical_and( + sorted_background_indices < background_count, + jnp.logical_and(sorted_background_y < self.view_height, (sorted_background_y + sorted_background_sizes) > 0), ) empty_raster = jnp.full_like(self.BACKGROUND, self.jr.TRANSPARENT_ID) - def stamp(y, mask, is_visible): + def stamp_background(y, mask, is_visible): return jax.lax.cond( is_visible, lambda _: self.jr.render_at_clipped(empty_raster, 10, y, mask), @@ -2315,15 +2611,21 @@ def stamp(y, mask, is_visible): operand=None, ) - overlays = jax.vmap(stamp)(tiled_y, tiled_masks, visible) - - total_segments = tile_count * num_segments + background_overlays = jax.vmap(stamp_background)(sorted_background_y, sorted_background_masks, background_visible) - def combine(i, acc): - over = overlays[i] + def combine_background(i, acc): + over = background_overlays[i] return jnp.where(over != self.jr.TRANSPARENT_ID, over, acc) - raster = jax.lax.fori_loop(0, total_segments, combine, raster) + raster = jax.lax.fori_loop(0, total_background_segments, combine_background, raster) + + # The level-specific map sprites are the visible track layer. + # Do not overdraw them with the static road set. + + should_hide_dynamic = jnp.logical_or(state.awaiting_round_start, state.awaiting_respawn) + player_y = state.player_car.position.y + corners_a = self.consts.FIRST_TRACK_CORNERS_X[level_index] + corners_b = self.consts.SECOND_TRACK_CORNERS_X[level_index] def select_enemy_mask(enemy_type: chex.Array, going_left: chex.Array): """Select enemy mask: left masks are base, right masks are horizontally flipped.""" @@ -2345,12 +2647,10 @@ def render_enemy(carry, enemy_idx): enemy_y = enemy_y_arr[enemy_idx] enemy_type = enemy_type_arr[enemy_idx] direction_x = enemy_direction_x_arr[enemy_idx] - screen_y = 105 + (enemy_y - state.player_car.position.y) - # Hide enemies when awaiting round start or awaiting respawn - should_hide = jnp.logical_or(state.awaiting_round_start, state.awaiting_respawn) + screen_y = 105 + (enemy_y - player_y) is_visible = jnp.logical_and( jnp.logical_and(enemy_active, jnp.logical_and(screen_y > 25, screen_y < 195)), - ~should_hide + ~should_hide_dynamic ) enemy_mask = select_enemy_mask(enemy_type, direction_x < 0) @@ -2373,10 +2673,8 @@ def render_enemy(carry, enemy_idx): player_screen_y = jnp.int32(105 - jump_offset) player_mask = self.SHAPE_MASKS["player"] - # Skip rendering player when awaiting respawn OR awaiting round start - should_hide_player = jnp.logical_or(state.awaiting_respawn, state.awaiting_round_start) raster_player = jax.lax.cond( - should_hide_player, + should_hide_dynamic, lambda _: raster_enemies, # Don't render player lambda _: self.jr.render_at_clipped(raster_enemies, state.player_car.position.x, player_screen_y, player_mask), operand=None, @@ -2427,16 +2725,14 @@ def render_flag(carry, flag_idx): flag_x = jax.lax.cond( flag_road == 0, - lambda _: self._get_x_on_road(flag_y, flag_segment, self.consts.FIRST_TRACK_CORNERS_X), - lambda _: self._get_x_on_road(flag_y, flag_segment, self.consts.SECOND_TRACK_CORNERS_X), + lambda _: self._get_x_on_road(flag_y, flag_segment, corners_a, level_index), + lambda _: self._get_x_on_road(flag_y, flag_segment, corners_b, level_index), operand=None, ) - screen_y = 105 + (flag_y - state.player_car.position.y) - # Hide flags when awaiting round start or awaiting respawn - should_hide = jnp.logical_or(state.awaiting_round_start, state.awaiting_respawn) + screen_y = 105 + (flag_y - player_y) is_visible = jnp.logical_and( jnp.logical_and(screen_y > 25, screen_y < 195), - jnp.logical_and(~flag_collected, ~should_hide) + jnp.logical_and(~flag_collected, ~should_hide_dynamic) ) color_id = self.flag_palette_ids[flag_color_idx] colored_flag_mask = jnp.where( @@ -2481,12 +2777,10 @@ def render_collectible(carry, collectible_idx): collectible_active = state.collectibles.active[collectible_idx] collectible_color_idx = state.collectibles.color_idx[collectible_idx] collectible_type_id = state.collectibles.type_id[collectible_idx] - screen_y = 105 + (collectible_y - state.player_car.position.y) - # Hide collectibles when awaiting round start or awaiting respawn - should_hide = jnp.logical_or(state.awaiting_round_start, state.awaiting_respawn) + screen_y = 105 + (collectible_y - player_y) is_visible = jnp.logical_and( jnp.logical_and(screen_y > 25, screen_y < 195), - jnp.logical_and(collectible_active, ~should_hide) + jnp.logical_and(collectible_active, ~should_hide_dynamic) ) def get_sprite_and_mask(type_id): diff --git a/src/jaxatari/games/mods/upndown/upndown_mod_plugins.py b/src/jaxatari/games/mods/upndown/upndown_mod_plugins.py index 5382e0efa..a23f9553f 100644 --- a/src/jaxatari/games/mods/upndown/upndown_mod_plugins.py +++ b/src/jaxatari/games/mods/upndown/upndown_mod_plugins.py @@ -9,7 +9,13 @@ class RemoveStepRoadsMod(JaxAtariInternalModPlugin): @partial(jax.jit, static_argnums=(0,)) - def _is_steep_road_segment(self, current_road: chex.Array, road_index_A: chex.Array, road_index_B: chex.Array) -> chex.Array: + def _is_steep_road_segment( + self, + current_road: chex.Array, + road_index_A: chex.Array, + road_index_B: chex.Array, + level: chex.Array, + ) -> chex.Array: return jnp.array(False) @@ -62,19 +68,8 @@ def _adjust_enemy_spawn_timer(self, state: UpNDownState, spawn_timer: chex.Array target_interval = jnp.maximum(min_interval, decayed_interval) return jnp.minimum(spawn_timer, target_interval) - @partial(jax.jit, static_argnums=(0,)) - def _on_level_completed(self, state: UpNDownState) -> UpNDownState: - return state._replace( - movement_steps=jnp.array(0, dtype=jnp.int32), - enemy_spawn_timer=jnp.array(self._env.consts.ENEMY_SPAWN_INTERVAL_BASE, dtype=jnp.int32), - ) - class TimeDecayCollectibleValueMod(JaxAtariInternalModPlugin): - @partial(jax.jit, static_argnums=(0,)) - def _on_level_completed(self, state: UpNDownState) -> UpNDownState: - return state._replace(movement_steps=jnp.array(0, dtype=jnp.int32)) - @partial(jax.jit, static_argnums=(0,)) def _collectible_score_values(self, state: UpNDownState, collectible_type_ids: chex.Array) -> chex.Array: base_scores = self._env.consts.COLLECTIBLE_SCORES[collectible_type_ids] diff --git a/src/jaxatari/games/sprites/up_n_down/background/backround_lvl2_1.npy b/src/jaxatari/games/sprites/up_n_down/background/backround_lvl2_1.npy new file mode 100644 index 0000000000000000000000000000000000000000..f4ad60cf0895dda70c5e3fcfdae9c863fd898a41 GIT binary patch literal 100996 zcmeHO&uXVd5L~Z5fydmUk{klAf=3S?78DU=@w5at;z5jL)r)^#!Z+9_*k|y?er^r} zq1bO~Iy3!EO@A%?_-&_as=K->A`AZe^><%=^Zk2I{&@1|`OBMEKir*ves=!X?^ov+ zXXiiO-u-rW{mYBnyO%fjKmYRj=T|pxfBx#H>tAo))*oJd`q{JlQ&C`=8+{*qb?-s>>PaVqsDeva#Nfd5n|CD!& z;=iX3W&f0S^YkPNx3YiAyG8NeQ-`vD%DZ`b5`|mYKjqz``0uGh*+1poJUxlRt?Zxj zZc+UA)S>L3@@}4_Poi)u`=`8H6#qx+SlO?YcOU7YD%Z;Xth~Ew9!Bo9vOg>D zKGH)~u9f{+d3V)3jNEHwe^%aoq=%|pEBmwZ?y7khx!216ti1b34^_EV_GjhYRr4@% zua*5-dH0bXs&cLD&&s>2=3(SsEBmwZ?jt=^b+|vi8Sh#ro#R_dT+9Yfc`6Kk|K}YJd25T{HjOd*7M6 z_s%}D_TJmIyg%}NkF32V8@-q#-}lJcdvnwB{>b+|vi6p2^kR;D-y>`9%}vYuBj5MP z+FP>Gi#hUrkF32nH!bgveBUE$Z^?!gllR>I|3dRN+qY@SS>A8^etDbi+hn%#HZ2+O z!sPGfcK-4<+qY@SW8QE3etDbi+hn%#HZ2+O!sPGfcK-4<+qY@SW8QE3e!aKB_bq;` zj6uEDdS&gsyIXIC_}=;I?@_O{R@djbpY>LV_uNf^8}(Xi1r4)yu-*#sv-aL=F6y<` zd+n(>$9gNoD{30K6Y90rBX?Pi%X%xsucp_!FY2|{)_IZe)>|Py(kQ}kM&3G5s`=8j`df+ z3Dd{=tLyjNpWTl2J$KuKo9ScyJ$77m2D_cLuYL!s9>k+Q)<=Bgn%x5X)>{2`Ry~ME zeXNi8$ThnK_N}%0?W}qbkNQ|2@sVqG3+!8K_1jtXARhIxKH?+S>=xMXt!LkI@BMM! zZ2g||d$r;#+1s#Pq&@$ff@0Uak1vTF<_Zbzj6=_wRjgoi`KTs}zUf`XIuA0ymkNH_ttqc@uRgOpSAYBg>@gq<9@8)i^umX)uV5NeAe3gcGi6mkNdHH zFCO2oRFA$5@>y%|+gbNPJnqN(y?A`TQa$=M$Y-s+Z)e>H@wgxB_u}#WO11ZGkl$MO z-qJb`@z(k5J?mTNDdU+|)@R!Ay|rfitZ$u%c*c)?_MY{TZ{?b4Wqqa{-&+g-nc}n%{Tj6|b&2DMsLp<`c{@!?;r&Q0r z70$QT?3PwO#3MiJ?~TWKO11Z`tXgrOy=&{fRy^v#_pFb2<(g?>{n6TR-rn_SJ;;Z8 ztbN2Q*J@gEpS^2uuTT%pV|~Oc*GvoRkJg6s_O3_kK|a)D?IT{f9<7z};Vl>*`=e{? zKHhlLhx1mdnHJVZt)ubCH@f!T-^!2r@cl|P)57|wbu=FNM%Ui^TlrBRzF(o zj>aS3=-PXKD?jSP_bb&*3+tm+hG%`ohwppqFYf+3)jr6NSk|}Vt$fI@R5LBCk6IX> z^%)<&@2znj*2r(|Tk&ij@+sF$3+tmchG%`ohwpo9oQE~?Tl-c#n}>YLHPgcSsEy%S zpYh@Q-Wum&jr`WW70>1&pK{H#u)eo8oR9UhXU|SOAMgE>zy9B|Gw;)ro0hjTwdL%& zxogQs4<`Mc>+j6__vEJK?M!Vsdv2~<^3j92eE#N?pJU#jCpRr`GgrsV)xGa6`RKu9 zwlZ5i`Ki6lY#k?CwH)lx%dExC+8+6-&B$zJwrct4>1DPuTRr)yz0GXJ$yUwhmLAJO z3r7vZ__ME@#hR}zJ?3o-M-9XHv!B_j<)_xm?C55Pc~kSXr8k?O&CmY4Y_{a5hGE{- z-nR4@fA;5Phuyx#-x_8%KbxQZdD*MwrlrTcYvHJ27=QNXWry9q=4(rDHb0x6{r$4n zl9w6=y{cjFdse)a5Bam+FH3l9<{j%>`K);BeB{qMv&AyB)z9qA%MQE!mS?R$M}GFs zinsD1f7bJ|$Gl^G^zQcd_Eh^=@ingd9`a>9FH3lP<{j%>`K);BeB{r1Ue=g*tdHJ} z##{MtpR6-mEJIuU%+9>*;O)0OYW+Fm!#>tHhxHN9YUIm$zw9w@SRcJYJod3>cv36Qw`YCy!iw*G&*rf{&R?niA2x8M An*aa+ literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/background/backround_lvl2_2.npy b/src/jaxatari/games/sprites/up_n_down/background/backround_lvl2_2.npy new file mode 100644 index 0000000000000000000000000000000000000000..ad54fe5ba7d0ded683ee6b0803fc0b804dde8351 GIT binary patch literal 101664 zcmeI4KdW6;5XEC_V`E_>yzZ&IS0oq(v9hpJL=dqG!9)ybB8l4g2L#C%u(7bROyxJc z<{}plKaxFJXU?2Gv)>jj@2oeoX4YDFZW7|tKR*8Kqffqg>DrIiemcK%_tBRR&)+*c z|M`dO=ND(^Up;vE{lnY$Z#{T;=kD?QAKw1v(cP!tfAsb3@9sXuZ(M)p-HWs9H*Z{= zz5V9J*)M1MKQG_B_1a;+_RHYjy;lyN7vC#uR_}i9x3J~Ms&jt!e(LpN%db`E{_OqK zYhlZeRpo#s&jw#e(JTb<;SXX ze)fLq^&2E|tIqw|`>EH$mLIFm`PuuaSH+e;t8VR6 z`FrYBM%Y7<;?|$l4v8C6lTl-Z0-u+(vs$T9> z`FrR0u0pUU66pL$hn>9y+CK9#?BzgNGim-|%y-u=|8VoR@8 zxAv+0z5BiTRlVG&^7rniUKLwG617!?pBO ztFOtx!^9gMC_rLXdL08=gjE5j`Mr^^X}dwpXCm(-faA(x2WfAp10H!mDy?TmRd8;I{A$H zn$feODJqxKTt!R8-jmO$uN5;{YL3d~G4dGh42d_`)jFngKBn&%20 zJsQ_!vgCMnkHB4KvF3S%3|b!6WU}P6W?zNcGK)3Oqi3`xlO<>Ld@O&_*^Q+a@ zWU%Bs&+q-eJ~&*3*D{MW&!cCwCX*#+^n5IT%PiJBTl1^e*JQHfRQK=IzszFIvv+@X z{+djdob3Lh^)0hl^Bk>j@A+#oSaPU$Svx+nJ{qQ96*CX=S+9oaN5%9@KI@}l`c*OW zAfNSWn0{1DzvQz%8m3))>!V@%RWb7*pY`mR`c+JSBhUrJe^h-YLqhb10G4mjw^=g=YR7}6*vpyQ8UllVC@>#Ej=|{!%OFrwPVfs}u z^B|w~?3nshOn>CFJ{qQ96*CX=S+9oaN5%9@KI@}l`c*OWAfNSWn0{1DzvQz%8m3))>!V@%RWb7*pY>{(ek@FX>}P#6Ous5-9^|uL4bzW`>6d)gN5k~1 zV&*|U>(wy*sF;4qXMHqGzba-PMbRm?odXT2Jx9~ILt`K*tI z=~uC7<=t zF#W2Sd63U~HB3J$reE?|9}Uy5ikS!btXISIV`2JZKkK7m`c*OWAfNSWn0{1DzvQz% z8m3))>!V@%RWb7*pY>{(epF1qBhUv$PtNHip z_2xysS5G!R{dsXVf6;opd66HjcklW1=f!*L@4e5P7x}&YY`vfUym)JU+2?ulB0sxd zRUiF%vATb>{+VC<|AYJY-)|i}uhH&>emvP6xtw>^@XT+1{vvOF-dF5>-T3{+;q;1p zAG6M`dS9paEc&ZApJ;vA=g<89|Mlfr@72twozJr}TYhim*UoR|)6QpkOV4_q`L*+# z`Ly$S)=e$Hr{=fhUd3lMFZPd~8U3nw^n5ITYF_NO=2x%NuZpYt_v%;kV!wBPc0T>8 zIJ>`SeQI9pkJi_FKK-h=xBlMy)x6l>TYvO^`c-kXzU=eWyx5=JpI0CKs@S`~I$zC; z{nh}4J*9x;=pCp8cDFK5l?d#?Ys z@0e}%40#_auI@itznUleqxD74rymtZ>#IIr&6EApf1kexsFFMHMa5NoK2N9Wm;KM1 zJHP+m=Z<-)xO$GNU-nn`AFZExsW@6+^nBGX`=j+$pU=EhJpIo=`u zkBTdJ?5R=Bll?ulZ?T(xRJ?_Tk!PrRvcGyZ^rPbH{-gD)d9puRU-W$XQE{}s>hsk+ z*+2a!{qy~|e;zIqym&7v&fvkLLG{ai&+et|jnj@B1FU-irW zXnob^GcOfS|2h7gQ{{b7?~DD}eNwNAv-``gPtBYC)8G8{taSst7ZtZ=XQ-)u*q^~e zWevnNzQdv-``gPtBYC z)1UsC&Yszp`gkuY-cn1%EY%PDBQ#y*nao$ky|YvGvcI?f(f2cN75C0Z)yw|g`bXc- zyj9#gA5|~=d+Q&4Kl4^`?|f9f?C-6A^!?0R#l7=U^|HUW{?Yd{Zx#2>N7c*z-ug%1 z&%9OKJ0Ddq`+MsjeLwS7aqoOoz3lI;fAsy#TgAQeQT4LFxBk)hGjA35&PUbD{@(gW z-_N{N+&dptFZ+AzAALXbR&noqRK4u)t$+0W%v;61^HKG(zqkI;_cL!5_s&Pv%l_W_ zN8iu9RoputRWJK{>mPkT^Hy>1d{n*c@2&r;?PvZfj?PQf!~SS}bAJ2hd-;B-I67Zd z5BsC_&H3}b<_)xOATl3*-c=l_?KoT2!~SS}bAJ1F`F^N4I$u=}`=j;E`SZT!4YY3{ nG9P~4RUE19I9t`j{%C!3e*1R$eyBJ)UsVtLqxH@Cx$o;=r+_vI literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/background/backround_lvl2_3.npy b/src/jaxatari/games/sprites/up_n_down/background/backround_lvl2_3.npy new file mode 100644 index 0000000000000000000000000000000000000000..3743b828c8a0ee0a98afd389e45c8e673c71755c GIT binary patch literal 49376 zcmeI5KWk=H7{u4s#@0qi7ll|vTo6kOJ4Hkbt0lM*7i3owwee56egIp)haWF)ioo?p zcz^HYoafwg*DD0y$;>=6b2bsQ@vqOn{_Kly-gxxOqhC*-US5BDb^7ti>2E)upPrqZ ze)r<)r>l$SPhMO-y}bSYrx!n5U%vYO_4gM)UcSQbesq3za{l;(vy%_sJ3IOP-gFBxAI$mv)29Q-}?Jlf7{>6TmI~P>wfdkzQ47f?Qi8<`_1le zeIN7BzQ47f?Qi8<`>pM7-EaQ2_f_ZH{#L%afA#&={pPRkzxI9G-^#DOuQk7QzxiAH z&F;7Tt^Dl!Tl?+r)Arx(_V+vde*60zejoe&-QVx<`|a;@_A>%c9y zR(#(28&$l?^5)jx$NDd~R(!Vn+4$aXFx$TBKKuI}{`bZGeGb2m{e154cliDG_c{DN z`_I?@KK3(E$?w0P|NZ%EeGmJyvis)W8gF~a+dk@v&A&F!zA3i;cF+7<<83c_+ebaI z`Pb&zH^tW9?wNmP-nJdX+djKt{#YvX12hkRy#yJ33r zTf(;eKA!vT(=5Mddh%=IW%q}CW`Db3dh%Psw*5Yy`|i^$zh`>#YvX12V?ML*F`lKy z{N!iH%WcPe=AL6bOO5%-kH*Wc5Bbc#=?(LfpB?w!Z}<2f)RP~L*(b&9+uP6kv*X_T z?H=EQdh(+&`=pqCd;58RcHDcv-Q#;uPkuCJAH(cNz457@_twVjpJMi>-}qF|d)1iz zQp|qz8=vZVZ*9!}DQ18AjZgKwSB=>(#q3AF@u{Bo*2e6gV)m!s_*Bn()tLQK%zpG6 zpXzyUZOr~DW`FvPPxZW4joB~7>_@-xsh;=N#_XSB_NU+YRL^_WnEg`Be)Jok>UnQ% z%>F56fBKD2^}JV&*)PTHN5Ao@p7++q?4M%xr{DNg&wJIF{Zh<+^c$b*d2emZ{wZdE z`i)QZyjP9c&oKLp)-!)Jw!Ox)?`S>qdt=*UJo}8+Gk-L;y~eZeXg%|LW7}gq`;68z ze>Aqe#qdt=*U zJo}8+Gk-L;y~eZeXg%|LW7}gq`;68ze>Aqe#?Pcb{uZ+x|$ z`6*_<6f=*0OdV)jch z^XNCeTF?C6nEg}CKJ*)3t!I9U*)PS+qu=;yJ@b2G_D?bU&~JRTp7|+ezZ5f%e&eh4 z%V?6KAu4ms?Y>q+{9sK&6+je{*1Z$&llf({^hrC-T3Xs@0SlAKKbtPUp2ysXsbaeUNwY?ESg-P|y1mPyNxc>Vv$CXYbFw zhkD+pcuFZIl?cuFZIl? zcuFZIl?cEdgfC+^$WlHp4f-#le~+$uNu2~ zsAnF3ioryT9t+<$L$(=280IdC%^z`gi%>eY$y+zIWcU`>Xz4zIUH)9;NS{_w4?vf0ysw zr<;fR-uagQtN;Gb{!`q$-`Vr^zJG82>F34y-uH8#i+kr={_CG#>kYj627314e*dZO zne95>%;Vzg>{yyNe~<3xLH+3Nna9PW^B;Y``*~15x*z6o@#y?V-|v1N)Q|3md0ad? z|Izll{Zr3A6i^xpcX68U?DKBl)U!XuQ=ekBkGzZ1 zd}p6``=*}#DW3Wit9|5MoaQ^*yxTAJ>`U>~_r@v@c^CK2zjuGPPwIKU;;HYARUYy# z?wx<{{%)Vt^M1us-y5qu4vD!a7@7_y2?^8VWd&lZt@+qFZuj>2gr|zd- zVdaazbSsx-?OFpxtD(A z73RFc$}i>Be)8@eKz4SN`NvtoEn=?tJe&%Ab6S-TSFe_j6wPlTWeQpZdG= zz4Iu4@+o%jr#{`!dF4+&#cF@*@6PwmLw{oCCBHY!-^VdO`Gsr$KRnep`MqJj7v?9w z@Ooc-BQ4DLLM{1)*ZbldX<@z>YRNBL`@Yzx>Yw~<*u6*T)qU0T^h>eIPoDe8E6jOf z&MTh!(OBgrPweKeo~K`mXS;`acpv%Mu*yrG*uAHEo_;Bw?H=agedK4uDld6r_nzu` z`lWdGdzg>+lHVGu-`9~>*v(JBbe?`GPV=PmDnI?HC%-jTdB`j5=BHmePrnqWdD3~6 zpZ?U7-x{ktAcEMf9lC^ja44<3cLB~m(J5K#c7^&-p$W_oF`w6 z)jskmW`5T%op<*seVQkoPxEnqHD*5glTUH#pU$WGIiKc9=b49C`IE24ZXW7g%=}x= zyZe-W_Ip%5^3|CA(Vu*ZQ~z{6&CmHXPdd*$DP|rQ^Y?8F_xzqfl=-rDa2-^13p`tx-2bALKdzZ9o=()q3PyLpv9&ChwY zk9_q!{Zh=l3Ui;*ljnSj>Ay8j^K;(KUp-I16z}~WH-EK0?Vs~*zUq1UrFiyx+*fAu{5Qat(|cmL7!>ht5hN>4tWr(cTGJnlU8XT$3AVjoIRKAopuiqkyq zJoRV8>hoeBN>4tWr(cTGJnlU8XT$3AVjoIRKAopuiqkyqJoVN1x_!Lv_lVVbT))@l zT7Ulkf5&$Ftk$QWLv&R6$EzpeVY?(fN59q*~eZl1mAtNWtg-uhr(-cyaa?`XKXFZ%7hFWys)_uh}n zU)>k|)c)Cd-cyZdzdy}a-534Re52=iPcDqnRU^i%uy&hy@Cy!yF$)#`Msezy8O8~3zsui&_EZ@8j^mdD=uaPziq zui)G~v*{~3XnD+bZ*`v5?G;>g-mU#AI%s)peb3hWTDMnlTi;W?ucCv*qxzoN{E}|2 ZGdtV8Zl3Ghm}uNQd($V~TxT}z<3B4QLTms4 literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/background/backround_lvl2_5.npy b/src/jaxatari/games/sprites/up_n_down/background/backround_lvl2_5.npy new file mode 100644 index 0000000000000000000000000000000000000000..a1348534223a4305a72f6f1292a0a053f2988bca GIT binary patch literal 63968 zcmeI5y>3%M6ojGW3DO|BGoT<5gd0ipNd&T?O*z^}APZ-rv3R{mzf+;r#T&$@KYT`t#e~bayiSczp8p zn7PP?^TVN`Lvw_^V9syNArn3Ps~sAGat>z`)cDueyw}s{kBb; z-}QTJWwXq0+n(6Bl_Bp}O`Fg4=kI^M*XF1BnUCfZZOr^u`I*Nmk7z&i%}?_)AI*n* zqBcLx&wMl=?upv`G(Ypve7GlS^IPS2(>!=T@~@`ZH>=5mdHSPi_Q`7YWuAUD&AwSp z9?a7pO|wr{voG`Xt7-PlYVu&7{%D$gvYLIFr(aF8Z&s5B^Yk}e^ZB!y{g|g;QL}SK zlLxc(N7L++)$GeW{fe5MGnzb@r9YZxpR8tI=IK||?3~f$!7TmJH2Y*V`!Y|zqGsoe zCJ$!mZ@T95XEpmVPrssO=Zq!~X6cWn*(a;nmwEaXH9KcCc`!?VG|fI)&A!ajuc+BM zqsfC=`lD&~$!hjxo_M6JeZ|Fnr5G@W?$y%SJdpB(d5A_{n0f0M4J8hdivEg z`(`zHFi(Fp%|2PpzRc6Frr9^E$%A?NqiOcZYW8KGel^X$Sxp|y(;rQke>BZLk!C->o_;mWzFAEk%+nuDvrksDFZ1-P zY4*)(@?f6+XqtVpnthq4Urn=bR+9(w^heX|lhy3YJpF2#eY2W8n5RFQW}jBQuKrj5 z5B_cc-O7vk+kXDu@5g?vy7%*L+WlL3F~4d3TAzpgT6OF5jFxZZ#r$afqvvJ6Rvo=x zHNTY?^X9Yq`-7_S$o4ruf3#@1vhO$A>s#0?v!nIB>E|^+{rlI>iwCI<`Zqx{4_uF(R{ow^NBXl)A#w*`!gTSC)%d@X@2IT`FLOE6K$ZU z@AIkmXFi%wv`zEV{LDx5@xII_+CWd==Tq;`d^DeEo93tanUChdPjuGJ zPxCV$&ByyPpJ)R;{fmo>Ma$Lm**4+cEo_$AP3W-BXD~ng`DQ+vPqbb0)BMaw^YOmS zC)z+y-{({B&wMnWXq)Dz`I(RA<9(S=w1J+!&!^s>`Di}THqB4-Gat>z`!b(s13i78 zPrX0$(R`w9nxE!Zg-_3)AyoHn<)`jl!80HI^Ow>0*!(m<^U-{~FY}2u(9`$%)cZ3Z z%_rKX`DuRUqxpDW<`Zq8r|%UL#`D~kTZ}Vw8pXR6enUCfZeFm7H=4U>dkN0Ih%MHx-ch+*X zp3ikX_ieqe`LylQ{4_uF(R`wFZho4d`Di}gm-$2+=;`}>>iwCI<`Zqx{4_uF(R{ow W^NBXl)A#w*`!gTSC)(!d{Qd!9hi4-I literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/background/backround_lvl2_6.npy b/src/jaxatari/games/sprites/up_n_down/background/backround_lvl2_6.npy new file mode 100644 index 0000000000000000000000000000000000000000..ca6da799af5807245f4477f1ddbca5a8a4897996 GIT binary patch literal 67616 zcmeI)zm8ma5ykN(C1=zJoNwl@OS_H*}G5u`S#l%zia+{%Im$>c)hRH&;F!a|J-9UDQy%ri=1=uWpY}`hl0Kbh)i3p{AI;v=`Ra*R`}Lpp)4%2EJga`GU;SwIzRp)q zyxOn-w4eShPv=?nOa1Ccv-fqrdg9f7{ipr(Z+SY;s$c3?KbpV%>PPQB{Nwx&wfo<{ zf2&)#M|xiBN3%C_zWSYS>m5taOZ{l}eR00}o$v3L@5rav_r+TEJKxs(<@50Q-5Q%W zahv~WzcerDyY1KJJ-R;KTfNO^zdY@yf6LQ(rv2Qf^*_4ayouXE7yXKKtcqKmA*t&NJ=jKCQp?`d#?)J)!k2Uwl6IbJ~B}?{oL~Slx7#NiKqRi{ZjsEH5no@@K* zFT20GeEB}?@~PhP#pnKco4?m?pQrUcS9Q6qw_n!p&bRfR`Mg^8d8YohzV*}p&d>Y4 zrRSG)pWk$S>t~)jANzi#=aqDym-Xs$s#kA0%}ajv%c)=8{%Lkz;xs??tLvxUa`ID` zotJd$6R-N&FQ@tH_D{3(5~um8UtK@-mXn{l?7XB~pLo^JemTuow||Q~oKz2)Sme&NgaWX`+%{PL;Z^2O)!x!Qm6 zxqMFhU3{PVIh|+KukE+qJ0!iW*RSO?FZtOor+#(&r`dUl)BM!0uAh3#$xmH&Uec{k zyy|DaoaU?BKh4fdoaU!~b^X*^PJZgL^OA0T;#EKU{oAj+OM5&{dAt>r{3~CfA{~_^6fweDVHjn-4El>Nk z^R1uGll;_MzUROH`}N^Flir69Yd`1j$xr=beEA;CPctOl`o#L(dG`0+>HbOgd04OB za`IDexy^I0@7}-md06kglWu+DUyhqPVd5_N1&a>Wql5YJf zr+MlvPv=YXSLdbvG+#aWC0)Ok^|xPrn$tY>mRIv0Jul@wI!`;#diP1X^&jJpeJ|$E zotJd$6YJM<@>g&D+j{3EZvCxK{z*@M>eFmK{ncC6ujQlXrF=){Y3Et*K1sKJmD4=+ zmZ$Tj`K$9%f10nJ{F1I;%lg}|KFw*KddsVMkDizE9-U`(p8aXRYwIbGddsW%oww?@ z+CQB~zn0y{e)Xd{Y_{nL5$YuSD5 zS3jCl9`%-2_i^5;-)jGK9{pN&AN$px${+h)EdAU()~mOy-<`kz`|b9g)_c#STYqg% zdDL6JcK+S{+k0B?{z+WB|)U%jXOw|&=^V>z|ubUiFsKe$#$w zUec%Yw0_o`H|f?tH>bSnEvNma{nEUoPv>d%$s!U6YD3Zes%k& z*?Ec6{M4Ux{jFEOHBNceTTc5<`=xnFpU%_zS#RE?TmRgg@~XF-_M7%g^O8QDr}eYm zyh*qIxjE%kZ#nHZ?U&{yeLByopZzKSt?OwY^_J8A`lb47>unzETXrA&)pz5xk9y1L z9{Q#FYwK+u>sxjo`_*^jw2ykr=^px}`fKZL9_w3nAN$pJvuM8?_s_7NV@g6&S@X@mT!Ik)qL$etapCWt-p0n`>3~k z>-(?fYwuyb^OJ7Gowk5s>!=hpkv{;T?}@1OFyzxC=Z>vuNx9`2{! zvVOa<`DDLKZ#Q;7^Q*V4-`UuExSx8<`t8QMiScHufIwr{1!DyRrM4U%h4h&c@!u{nT65 zZ#Q;d^Q*V4-`UuExSx8<`t8QMiScHufIwr{1!DyRrM4UwxJJKO1`=_fv0Kzunk< z&9C0FerIFv;eP5Z>$e-buldzm*6(cWJ={;dW&L(z_cgzI%le&-y@&g$x2)f8?7k_# z^(`m=-F5HbKI$#&w;Q{!`PEz2?`-Tn+)uq_{dQybHNSex`kjrvhx@6wtlw_zzUEhN zS--Qf_i#V;mi60>)4uML>XWYD+1UNvPrYURc4PN7zk191osGSR`>D6A-)`)_=2vf7 zzq7ITa6k2y_1lfz*Zk@&>vuNx9`2{!vVOa>`@4^N%le&-y@&g$x2)f8?7rq#Z&|;y zvG;I4^_KP9josJ$>MiScHufIwr{1!DyRrM4U%h4h&c@!u{nT65Z+CWo_fcvuNx9`2{!vVOa<` literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/background/backround_lvl2_7.npy b/src/jaxatari/games/sprites/up_n_down/background/backround_lvl2_7.npy new file mode 100644 index 0000000000000000000000000000000000000000..b970039ca5b51fb3b097b7b7a5c37fa045c5e039 GIT binary patch literal 38784 zcmeI1&1xiN7=_2JSKvnIT`Oc38D?B~0qztLT(}y76HyRH5_RJb2r@U|1;?pHnDAiN z)2H4q)%8_Zph=!S=RNOFDlqi?`^DFvfBDUOPkwvy`~KDS?YB4kPj~x2e!bjZ?DpTi zzWL?m>Zg~lZ(dyw?|*jn=-H^gzdc*qvsHbD4%KhZR_)p1UAKE4 zzGr$H-uB!%y&wH+z8`l@dCj*uX1DZS^dHlsR}Xhhx!1j0&7Izh{#J7=F^9XRyu_V` znV$Xg>?E)Be)OMe?m7B$SK~Qm>Zym_%k-XhO=q>;O;7u+&2RTHeQo!6xR32_`opYq z{Mqhf`kHTMyPLk|+x*{dyazqhQ)XUI%w0J}>c?F=3sSL!MEe4AnRbYJ?pzv=Y%%(g4_lzYC-FnhW${oLPl`g>;Em3qoO-)5LS z-IspuZ#w-wv+YVf<(_Xd%%1K`KleAC{+`)(rJi!nw;5(n_obivn@)evY`apg^3reB zdOzmb{`4DLUPZUw)UB8Ksi$oHPpzl>(9iizH$UsitZ&KeWo*4!Pe1jPnYT1f_o1Kr zm~MX7lUd)A*~{2^vz~tHDKl?robE$E_c7i4tS7U+C9{{Y^=3W&)Kg~O(m36Re(qzs z`B_h9eM@F9W9!X&`l+YPyrprv5B=Q7bn~;G%=(thUdGm&_4HFunR!d&bRYV;kLl)T zJ(=|_nZ1myH|y!Ao-*^6#_2xvb05>q&w4WJTQYkYTW{9WPrb^_pPO6H&%EkBt@A4L zw4TNCsHa@nwMA_@Fa4asboz~3=UKloz4a`XM?K}SUFOizdFkhjrqgfSI?wuz>8)q6 zJnAWr?J|d!&PzXMG@X9q)_K-%Om97lF4~W(|^uvx1yf%sc)OkOF!o~o&IxXyA}16=e}(^pZPgI{nR&p_ivZUPGrK7VYpAt6 z>M2*}v-)%%`k8M!{l?5Qw!DgNy{T7v4Yig>J>}|rWA(#l8tFXrGv9RjjhSa`c@^Dy zQ?K+IYAug?%2Ve%_Oq~b9{SnGboz~%XKZ;D-Fj26^crd{k9x}0`K&&jhkoXpPQNkp zj4iLCTW{)>UPGrK7VYpAt6>M2*}v-)%%`k8M! z{l?5Qw!DgNy{T7v4YkZ8GoO0O%&T&>KkKXWSpD4n#^#xO7VAen<*{9-Xx09#pE5&I zGk3qS+2)?b`cY4LY}YATwLj~p%#hT~-EVBRxo5F{)KebYb&6K)&-y7dBsFvQ8=GzJ zS*#!Rl*e{`s2%!TEbT`>b4{mz<9Gizo!-K3O+Dpw9{SUM^jjYFji*0Lyj}Xcpx>DN zQnq~4*^hqe8*k4RpOHh%XN5lM8@GG5cuTuI^_1CXuEvG4tp*o&Hnv!_UlypE-}|^pjan Krr((L3I7L`Q(d$G literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/background/backround_lvl2_8.npy b/src/jaxatari/games/sprites/up_n_down/background/backround_lvl2_8.npy new file mode 100644 index 0000000000000000000000000000000000000000..02e2d6604d7842d786b69faf9ce621e9d9c8a704 GIT binary patch literal 87680 zcmeI4F>WSB6oi+QJ0LkTNOoXk5<&t3OCS=2h>*n^fdm-YKmr6LdjS%T!5Q{qKkLBK z>XTeG&pqww9@k30yybe`^XhZW1{?UF&%gZai?82#@%xKEj$dBgetUEL@zL?0znvYQ z936jmee>(h<!#+`mpXL`nnos%^KI}kGo%d1qEqpW| z*Uj+L{K7}`Nx#C!b)cu8_fz*Td^8`|&G6Iw!bkH-zrx3Lpr@brQ}-`?G#}T^@YDRl zNApR)!pC)>r=Ryz_b+@jAJ@(B)BM6m^GUzL$915mpZ8PuFMKo~*Uj+L{K7}`Nx#C! zb)cu8_fz*Td^8`|&G6Iw!bkH-zrx3Lpr@brQ}-`?G#}T^@YDRlNApR)!pC)>r=Ryz z_b+@jAJ@(B)BM6m^Kt#6AD)_L_-H=t6Lt7$e&M6}uus(Cr}>4C=EFWwho9ybKAI2v zL>+#bU-)P~>=SkPX@23O`LIva;ivh9kLJTZQHP)A7e1N~`$QdnnqT;6KI{{9_-TIO zqxrB;)ZwT3g^%XLK2e9C<`+Ji5BubFcxs;Eqxqy?;o~~c)6e^<`xic%kLzamX@23O z`J`Xr<2umO&-t^_Ae&M6}q+j9VI?&V4`>FdEKAMl~X837-;iLJaU*Y3A z(9_TRsrwf`noqTx;i-9skLKh07Jiyv_-H=qSNOON^z`$7>i&g~=Ht2r=Ryz_b+@j zAJ@(B)BM6m^GUzL$915mpZ8PuFMKo~*Uj+L{K7}`Nx!z`!+yG-M=$@Ao7aaCKIBo= zBV_2gLzM^Td+s*F-Q-c#Gh}SpvC4zQ~c~o`l^G82l<-z&U^Yz}3JgU0)ep}yP<-z%__uu;Y zb>Mq)z#BVn-fy=193SV;mbbdE%kTW^^Y%9H z_&9%WJebGjcRufXZrbs2{@!`H{LbGyAD7SZalX5McHiZ9{_Oc%_ji1p-}?N~&%6B2 zA3b01{Tv_X_ug;o`(1wLZ@vH4&v$&BzxDpTpYQTJzxRHl-|zT1fAoB<`?>tiZ+-sk z=N%vC&z|4y@A5m}-M@F=@o|3mTptvpz4LMT6q_ZUyMK0n2?I4Vd;Zq_6`Lhq>+?rH zzl4FB89iU`{S=!eUhn<3zJCb=HM8~pTR&g1S>kQIfA8lnVW4Ju?>GAWip>&l^n9)R zEn%Q$TAx4rdBtXlH+z1!{}Kjo&$#>c?q6J799ll%qxtQfkIUE6t-G`C{@MNTY5RRd z;ivh9kLKh4{RltJugZtKd7l2>H1k$9dGI{_*);Q1HS_X3{k>`Ct!nb%dHSXVWgP^O?6g z&-0^cm&f@o56`>#(S7ErYUbnls?YhYHS<(8_viVl&-u|c^Hw$U@qE?i{MMRzs+#-r zeAVat=$d(}n)!IX>T`Z;%{*1j{dvCXbAEKqyj9J7JYV%WzqMwbs^eOh&#&$CzO6O$ zRyFtM`L%uC*IP4BRdYX{ulk%HUB7zu-eJwp=y|LAt;y!@uIlq=pAWy+@@F1Cnvd&$ z_-TIOqxqy?;o~~6#6SJ%-@{`*t?%b!o;lN*@OfOX?tjMNr}>4C=5ybphrWiN<`+Ji z&qKYs|9<#se&M6}-1lhd*YMOl!$KXw0}EuX)CIGlF) MH1q7R)0#Q|2Zt=T6#xJL literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/background/backround_lvl2_9.npy b/src/jaxatari/games/sprites/up_n_down/background/backround_lvl2_9.npy new file mode 100644 index 0000000000000000000000000000000000000000..47bb0db5daabbdb92930c99e0dac128192727d2e GIT binary patch literal 33568 zcmeI0JB|}k5Je$n70wKj2?j94MC0~;I0K}qc+tFNnG ze_t2UsC-?wKJU09{=9ng^7Xs>_kP^_xxU`seAus_t=7N3U#vH)^~c@*+y3gy``!L} zd-(pv)#sb--S;=2uD)*X^ux=`&FbRv$!7KR(Ps5~75Mky@!P+*Q~ulScDE&ee4FY; zKee}B`P=fGox1y5ZR3AVj=zbjp`Y4Yul#L!&Q6{8#T#h5AC+UT)OlaL0d+%ddG<=3 z_r)7%yC0Qfuhe;8ya9DXZF%-eJ^p?1H{kE*&3pacdGssIeNy{wPVrTF@}}RL_xio_ z=vSKiq|Wd}HM0)#Rz_m7jTP_E-5*KkKPE zKXu+$^9Fb?Uft}KpO0_rz{AUvellgATKSo$_Vz3Fv!0s!r9Qqf^73l(RQ1ZwJT?2P ze5s%H)SREX?7qlP@h;8VyyyFWw}(IWv+pf-oo7F#ZI7i!#Yg4IjsC9lUH9|uOTW_G zcWGMjQF-#Bzw3P0{k;3quQc~vnpS*Np1kPqI^T6a@4oaa&3%`q6(8o8#;5sSvN%UK z>zDZX97avg5*gW>n;fE_t=Cebm(xqFoC>N_6df%zdGY-)>wfA(rPn3(*b-mBr`Wc64oZEZ7 z`$oA~op)Yse&Z00wejG4qFk(w{g$SULp0XTE6T;{YUiWgXB?tY^`Cn_%EjvD{{GVa zjYBk+&YyaHl#A6({e6A!XB?uj@BPN^ALU|oWA8uoe0z5^4!!Tt^P*g=Zs>idHqSUj z7;I`&(dHV)BPJFh4g ztE-)ldY^HKM%91r`6w5woBR7q_cspFSUP{|^-(TXH}&`Ry`OQ2#=iF(yML66)s4OX z(DUux(Kz(JL(hwHvAUu6o!UI(5RFrNPuKfIxmaD-`^>GMafrsbz1O>Ml#A7Q=hfyn z4$)W}556bL#p>8^Y1%kMW9_`6T&%8kKI(nOAsSWxx#y!?tZwe_FWuibL}Tgvsn<`8{zK2VcSqyU`X%o~fBNR-tM6XifB)0xKfZqP-}hht@cic&|KazZe*E#1o2Sn{d~);A`%iBEyz&2f z_shRNxEp=FdGqWp@-;epbK~x>(NE;=?q{#zVRXjq{zl$6_d)2_hW)h8bWdvX1Czn|{@-TUgM?|whs{dapWuHWwW)7`&& zU$g74eIGn8JHOgT-o@GekAA+}H}&jC@zh7L+DG2SQGct?yM0s7{uEDr6svvYT^#kd z`n=mW_3Tga)JL(}N8ZIzf2+;A{Zh}q6i#cPkl93e#yJI zI$!mEw=e4HU-8seW965;i>vch?|1v6p8ge2eKl5o$-B5ZU-f>sFY4)E@zhsi<(IsR ztMgUwcl%QM>O9f?Dxb?o{dAvq`%?Pd$N8UE>fT4xzj`k&AN5l`@8(td>O9r^RX;9Y zohQ2A&8zfLKS!HaeJQ?sA1a^XyXVjLuliGbw!iH2RX)XM`|Ezb>QC`I|Nb+xqc@=5 zkK#vXci>b$#dpZ~e|ltJ+)duaJa0Da=AoW>6i@xpvFd}oi)Zi8%|ku&DW3YHW7P+F z7th|Gn}>SlQ#|!Y$EpwVE}p$VHxKp9r+DgD!~gvA`0n(J`A@mG`ONeFZ~Qi%PiN@& zl)L{-e&h4sPpiJjySTb<*Dv+-uXyTL$Eq*#F0Ss^^-DedE1vq*vFeMwi>v!}{Zdc= zil=_#WBv1z_wQo92hNk<`278-K2)FNU99#;dDjp1+^=}*Hy*z?)hBruXZP>=p`QK} zPyNRGeeunlV!n@4?*23RjmPgv^}~LnSnZ4QQGc9|`q}g2_s71YIJ@ttKh9_SUHyFa z9mUywM*VR<+wbb_kG`LMySTa^*Dv+e{*S(& zeY?22AJ;GS)&7sZpMATyx*yjs_0|55wx9jFn0;`b{I#*#N8ZKP_Pg47Zr{|iFU3X~2h z)MsPmhrElk{dYUx?T>oqRXp|8SotOI;_7_W``x~%r+>v$UyYSt@-D8U`Dv-M*-&f5lT@jg?>WF0Rg3y+7I)=d1HX_p3bQ6^{DpKCk*A&%BDK zes-+#kXLy2yxIO#Kji60@zl?bRUYyR&z?8ipX!G^{V1OL*|Ew)K8k1WL;2&+J99D5 z=REnNVU?G>i@W#b=B1wb6;FM4tn!j~@zMH;=I4C(ew2Tem%NMB{@HmqFZIl)cZkiW`-tM*$NvAZ zp1y(A-_vOq4}ISKjDP;rL!F&|Z}&IuZ}&6q=kzT-^m+F;?r--q?&qO5b^5*C-?+cs z&$yq{xAf3ww?Cf2dGg(`%0u48-TH9-x_wa3eTt`kcC7M{ck%4`qyF4JsONsgQ$IUa zdC0qX_WV(QZXeWhzv8K%9jiR#T|9gKs6V$4rH}fVeV+O0m%NL4UN`LKrJi{ePknc+ z@{)J)(fWzz=Y02mlz)|%yo=TT*?BiF^~|Su>bql=m%NLQ*3WABuYG^)kLQzjG0&R~ zyLqT*9>r6CbgcRy@8a3}bMsKoe2S<3=vehZ-o>-`=jNfF`4mt6(Xr}-yo+b=&&@+U z^C_PC)v)S^yo;;*bp29K|B9!6b*%a#@8aryUBA@Rzv8K19jm^`ySTbv*Dv+-uXyTL z$Eq*#F0Ss^^-DedE1vp|&;R>7)gO5mXZP*;p`QK}PyNQ{?}>e=KFPb7`>u^$Kh)EY z;;G+wzc0R-Q_S~q%H4k^zw!7zu^-hZc^7lvwXy4mdiqg3^&5NdE&B7+IG^1=_eF8G z-__5L-Vf)q`{ceT&i1?d`O*8~e0HDQ7sc6rS3f^`Kb+6*ll!7L+wVUA{NLARzpvT- z?e}#4p4i9i_cyyg_PO!-dz=0KX7{(>)A@U1AG6=z?Ecv2>bU!Tb@zYt_t*V?y8GYn z>+#*_4y*eadAo$G_J8#Km)yZfQ{B(V+a+AJ|D*4}V8JvF5#;EAASENcQDda z_cQW#30Lj^==(3ZgOR4XpOLpqxN84L-+##+j5O8#jJ#dKRr^2M{!2O^Y1jv~e-*-sgGi{kGzYc{#KuN`=*}#DW3W$ zR{O}iIO=cpdAD!s*`MO6k7Bityo;m$R-bqKrk?#Np86TmUVw{Pm%pW>;H zVzrOFi=+NlpLhGFp8YAF`Y2ZW$h$b|Z}oY%Z|d2f;;D~fwU4}uqyAQ(cl)NE{VAUM zC|3K(yEy7^^?A2%>e-*-sgGi{kGzYc{#KuN`=*}#DW3W$R{O}iIO=b;dADEc*_Yy} z&&J9Rc^7B4iKjdAU?Z5l^ZlBaMzv8LS#>x+Q7iatLe!kl$^~|q$>a(%(L*B*N{=1*=_DMbS zE1vpnto)F7akl^N=evDU&-{v~J{v1Pu`p89O8{E&BXw*T(uyM0p6{EDYO8!JELU7YQ|+xc#P z)HAQ*sjtS$FL@VN=d0fD_C-DYE1vplto)L9adp1x{cc~>)4$@Wug1zRc^6mbtKRSS zMLqp1p89I6{E~NZb-wETZeP^Xzv8K{#>y{w7gy)2-tYEBJ^d@5`f9BFl6P@+zUuvM zU)0mT;;FC3$}f2rSLdtV@AgGK{VSgOYOMT{kK*dQ+54mYaGrh@Pkl93e#u91b>8g# z(SA5jzlx{68Y{o#qqsV6_Wo!;oTp#KQ(ujhU-D60oi}@bv>(pXui~k%#>y}GD6Y<% zz5m+%@VspQ+2^Y~{)JoU3-$kNR`_pq~2`PyOszyyRWny)QQ}^~|q$>bql=m%NLQ)=xA)=ezf#{Hwg= tU99%c&bxW3XFkPK-yN&Gx6@~H8aux0jXi@|*5~8G_10f^|Dl{U4D53<-01ba`f|e?`k&7FW;z&n7 zS@UF{@BKa>A3M^>=kZ#5t-as4vi#rQ{QlR!{lgDF`Nt>!y#M^|`#-(A|Ha+?zyAK& z{j0nCKY#V^Z|~lG`Nyx`eg5{l@Biw}7w_MG`~CNSdGpt|-|~-s{`%G3XRm+y>h5Ph zes%Zny9@sN;qU(Q>4z=R9{Q~R z*7%{D&%fur=h`Ko&AT?fW~V`mF!f_@SH6zvsQ@+9jXOyEeY$rt=x^IoBTg ztpC>dp_|XY=e_6JC7;c^HooMh^BM0s*B<(;|JL}So6o=Jz31AB&*mS^TW@e*^VEx- zFRr|;diASM`)Xc&%F}&T`=|Trcc0Sr%gz@&znZIGecE61>QkQXv)Vu1SHJs|u3vV( z*!k65{p!>HnpdClbf4A!>Aw2ir*!?Y^Tp1u=IU3U_D|>aPkFlE$^U(SZ#LaWzdfbv zFE+2Z^0w;LufDp^Jo~0>-fEuiqu+f>*I#U2api5*t6zO}pLzC8*}T;}-ABLsl&-(n zyyD8+s#m}I>OS-Ao3eSUdAg5&_bFX}v3bRnw^grx^}YMdw|~m!O?ld<-@d7?U#@)p zvh%CizGC~!uU_+|n^$b!lqT_+jrvA@4=>d z`cFBhuce=Sr?2{pPjjcQ`CeaiO9(>(o^r!JdUyn5ce ztz7%7+oxXae6e}@)sN=dN4))GeANAttzSF$&@2z`Z_31CJJazL**S|N{KI+BOeW!Wl-&^;->eF9b zd8M1LUwt)KpZb)i`%d$pn*Zqcu+R5WU(MC0KIQ4Y)BLCAKl(lF^S#tpbM>iDdAjd3 z|Ec+peh>S6FZI=2ed<%5?mNwYYW~N5k7@r_|Iz<{+_(Dm7gyd^z53OwZ|bkS((T)u zYoFD+_f)?A;>z2qSHF7oP5qTux_x_d?Xz0TzOmd>Q}G6slW0{w{LH*eOBw< zQ~COfD{re_{p!^>^;cf$_8rZqzaP7nuhaYY-uv42nch3yf3<(Q@816Dea~&5>AlnU zTQ}d4z1aCF zSAOZ$SM#M;e(9C3Zk~Fv^Q*b~)$La=c7Dp0UwZY`eCd^6dgZH|r(W#*YOa2D`_+q` zpK|4wUVSxRdgYg1`ReAW7dyY2t6$xI^8|OtLOKAPxrM?eahxd`RKmdZ}mRr*{?q3%2Qv>?rXmKl+By+(S5bw>V3?!Uwz7z zr@orq*L?LUn>Xd7`)a?{`Ixg z@5}D*oO-eIQ?7jVY5vr2U-91g_S;vw{=K>OQJ?a3|Go3CZJ+7A`n{)g{d;rmqdw*7 z{(I+N+dk8K^?Oh0`uFDAM}5lE{rAqlwtcqV>-_Y-qy4q7dhvAs%2S`_PyO~4SANZx zZoYo?YvbBieadz3>b&YJef2)mJpJxly8g%J+E;zbb>Hf|>MMQqKGQt??pwP4$L88s zeady;>b&YJef2(D^PI2!SLv7j-mK=lN5A@%x8^y&nvcH6KJO`A|Fv`7qdsNt-#TyJ zR^GbLbY8#vm#+WXx$aS)viEPDH*YI%-Df(l-~CJ1f9+iNs88AZx6YflmACFQoiG2^ zd#-)n``uT4%9W>nG<%Qz>Qk;f^;=`#%lp))TzTq8v-jAqKIO_&zcu!~yia|~m8X6* zdyoCvKW%Hgp`+mNM`jpMPHlE(2-+M~ef9qWDp+4nn zf4}KH`n|7o{kP8b9_mxR_V=6Kqu={V*MI9=@1Z{BYk$A#J^H<`bp6kb>wVOx{Mhd~ z-B-W+m#+W0bG?`Plpp(jr~B%6|I+n8cdqwRpYmhB?{r`N?q9n8=guE{ukZdANqx$X z{l3$E^}B!R`X7r|{Tuin?Zt@~D<`U`*hVe$rB|NqqdWS2|4)nB@u`qf|f t<>$+vk&m%IFM8Eq_~qxzpOKHTKQDUKUwHQO<$Hbe&Fc><|CvAK{{yMEo2mc+ literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/background/backround_lvl3_3.npy b/src/jaxatari/games/sprites/up_n_down/background/backround_lvl3_3.npy new file mode 100644 index 0000000000000000000000000000000000000000..f3ddde6c97393cbfaa1b1aed080dc18822388e46 GIT binary patch literal 38432 zcmeI(ziuQ&6ov66a^7vS;CPB#7p zbA6_#)LqrR(n$09)UCPSaccrg`_EV3e);ux?>+nd*&q8?H+MhW?myq{|NQO6{%W`X z@y+e8x7V+~e{=il=J4~E*FWFgy!-jxPuIWPyyFi(`}k`2;?s|=cAtEBwfk###=rOf z{_mTIo?ma@K7VNW>baYz?B24w!`)BWVaaEZ>zlHDIB(t?>yxs+&YO?M`lf6j&YSnf`lPI{^X8+mzA4*>^X9#=J}K+#y!mLX zZ_4)Jy!q1E{ZrP@dGpa&-<0jcdGp>_pOp1=-h4FHH)Z>9-n=)~CuMz|Hy@4lP1!!2 zH}8%0Nm*a#%|~N>Q??K1&6m#ZpR#_=n~%o&rfeV1oA<{0q^z&==A*H`Dcgth=Do2# zDeLRJ`DmyBSG}bp|`*7a8H`b?O{XB0z8ta>~ zeK>F48|#y@zRsJE#`>mgAI_Wi#`>hJuk+@kvA!wWhx6vWu|6s5>%94BtZ&Np;k_pNjSKy!mLXZ_4)Jym@b|Ps;i_Z$299o3ed4Z{8d0ld`_fn~%o& zrfeV1oA<{0q^z&==A*H`Dcgth=Do2#Q@J<)bl*$#bLx(CpQW8Y<@S_JA7{)*BeLb&#m3zK-u6)f?9<86L{?7N_PqlyLYo2oT`Ag4JU(YYq@Abc5Jk0U3A9pZ6 zhrM%2e2&lYj4sF7;l26Vxa!0F9A0v#=lC4Y=yseP-kYzDt3J%n;U#x^j?eLoZpYc- zz4_X>>cf00FRkB8`*`2deCOV$>c>3grTIdw%U)^<|#&-2JBedfxr3-18eh{ySXNpLxoy`%d@qy!%zT=Qn=* zJlRLpr+LbL-laL+$Mfz}<(}Vod%k>Tj@jqqI6J&I-}vMzab_NhKU^{LKRpSSkB Z>MzZg-e>AlonN}|-1}AirFmZ|MAP;y?OQJpTGI-%RjvOvi|O~7e9D% z^X$d*CpSO*-jkcZ-qip1?We!|@9CO;zWCzBsq^QWZn?9*``1qA+q%=L*;?zpRr~(ke7}kH-o(gn(*1r+e`J=ewccB`@88Y$ zn^^BnjQl3u@5l5*5B`E?)$O( zcYptGeapA`-1pc0{k!$|`mGg+w{8kQ-9ZU-T5Nv+Mo7>>znAqJfBO4>pU!`*-}t%K{{8LSUg941 zeT>ti^&9WjqyDaY-<*AKZ@gQN`n%rm%iiRDe7~RG_rv?%_bt8&Gd<(nVg1Ib zXEpV9J@)(`}s|(yxMR3 zsAo0xc0Klu+5O3Rf8*Bw&ivm$K0UqX$NI^u^?kcHyZ_tV-KTnQ_VdQZtv{Ww>T@4! z>h0R!)cK8g&uV?Hb@w~x$yu|XHTzjxKY6w0zFqJ4b^Zpb?}`2ATKo6f+q%0K?`bt} z%zkpKInUbW$!(szT65p7-+pt|d$OPRG|qn3oM+8`t8JdVTHF3>=hb_%pZ7D)e%734 z&3>zGp1fMy{%hyed$OPRGtPe2oM+8`t8JdVTHF3>=jlB;pWc&tte^km`6>A^J@RUu zczX=9)qAm@H!#kAtEtD@=8c;kd9}9vkDXWV#eUw$IQy-p9&4L7ZhGX^+V($oUcDFl zc^~8Kx0-saZQi))kymTm|JZqYFU}wRUY~#d{nPmm^PIDX{7${cRzKt>=QF&=IlKFv z{Aji1A+Od)yGPY;^&af!KI80Py|z5$)q3^(slVzy*w6jO*}r;idC05v>iJWD)qAj? z`;D`IRQb2z|PwUn1VgAzmHgEfn&Y7Qi$g4H=R;#Ods{PE%dGe#x%rm;? zJ*ZEEoPrh1BzoTpBr9S!5YUUYT^B&YEU#+I!(KYi@ zpZsVw^Ng-}59*VzR@3k3nt7>Dezcl-(wh0HN4|PZ|7p#<)FVH7&3tLidr*&j^_u?E znt7>5e)O98(wg_69{K7u{iikaQjh%THS?u4??FBC)oc1sYv!dM`6~5`55IGI$Pe>% z*AMYHbEf+q)SsEPV>k14*Bu!;HM{RY{mv{iw3)BFo{=%x*?kY{Ctg>aWxnou#T(q` zsrw$(rx$Y8>_1w~JeRIlzt^Sj!F{XW$NVwRrEBvyn_v1K+&4Rax}SM2U8nxL=P!K^ z?(5Fit2-7Z>8>qLb-(F#^QvF# zH@`MdZhGCEe!6SRQ{8WR-Ms48`pvJ+lbc>Qr=RZH@>KVmUN^7$wSMz!^W>&Co6~Q1 zUCmqdQ`JxVtNyIN>c{lRtNLj__jlJ-zg0g~{j|U8&-$x=Opm;(pZ0TqcU|>c^;6YP z`>Xz}zv{>I$gBEkKlgXnRlij~RsFQT>d*SCeoQarRX^6B`WZcMe#vb=`$yNde>AW9 zukNqvTYq(bcfa{1H$UtjUEBWAyz0NYzp8Kj)&1T5=9k?3uzz%I`$zMt|LXp#zV%o4 zclVoLa`VIf(Y5U#&Aa_~>yN%ix1QaD`>nsqslQrndFa>rtDO2rt1U0{SbvpMf3@23 z(69AZIrWcLTVCd|{wk;bYPIE|U+b@O>L0DPyv$?$RZjh**VX*i|J6NK`S&opeb_y! zyxZ@k^{e@<|I+?v@3VVU`RqKS_gC{<|7d@!>)Ac3eD%Clf7SffU)_K7e!EAN?|Z!c z|G&ul9sNG$pLs{uwtsa#`X1CRfNXul>Hx-^1wN_xsxKYxJ8s*Sq%j zW%x47{A~Vpo=5(f<+EKj-_6Zm%FVy&+x%>9{!?!GOyB0ax%o@E`8R!=pUusG z$}OMi+k7`Se7Jx%p4IHs8(7U&_tD>D&BlZvInl`ApyDySe#Gx%oGJo1e|if66VN>DzoaH-E;> zUpjC4-Q4`A-29uq&3AM2mvZxO`ZhnCoBxzsKGV1PZf^ckZvIW*=4W&BpK{A*`ZnLq z&0osRzv*n40QvKQaZGW0Coi~3eH-DyY^Rv16Pr2nYeVgy* z<}c;u-}G&MHaGt%w|u5=^WEJ1rQH0RzRl0(=0D|@&-87+o14Fscl)2MpWcV_%tOwa z{j1gVo7T){dh9=1ZTYI4`E8#4tJUVe%9+pR*?+X!@>MzW+dTVMtIdCvGoQ_~|7f-4 zt8(VIdG@bXoBt|jKAUI%tooDx{OI(MAIq2Whj?0PCe3g2D`oBQV);|v;rWu9X?~l( zWR@AbEPu*pWE#0M&2RG~y{@Qb`BT1P=8C5@zs*;?9I@B(r~HUFIrp6AxB2s3kMVcQ zpYo&MEX{B8N57BxxBMx${j2jdzs;|nKlNw%Q=a;pJ)h>c`Punq@3;G;e0IKcf12Os zQ-7<^TmF=R_T+uzO2zwNjAl-vH<-13<}n@_pz z@8;&;_S<~QZU1a;`OKfqr`+~;bMtTeZN8hEzuDaKn?IXRx$W=f=HK?)e9CSAY;O6? zpUtP-_IGphZ~JXN<+gt|w|wT$=2LF_ySe$d{WhO++drFIKJ#buDYyOI-2A8eO+V$S zzwUXR_T+uzO2zwNjAl-vH<-13<}n@_pz@8;&; z_S<~QZU1a;`OKfqr`+~;^KSp$`rY?Q^)0{Y+q`kppUqSMHg9}({&atu&*qJ%{=4T> z|2A*DJ72ecn$PBqcjrs>Q~x$^JoP_&KFw$I#%JeG_glV{f93C=|NHsrG5)0YvHAZ4 D87?yv literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/background/backround_lvl3_5.npy b/src/jaxatari/games/sprites/up_n_down/background/backround_lvl3_5.npy new file mode 100644 index 0000000000000000000000000000000000000000..05412646974a10c807d9f50d6cf1b8dcd2877b91 GIT binary patch literal 71264 zcmeI3v5Hes5Qeq&Dbht@Emjo4(!ydvu@SM7MOQ4um8{yRudt0TVsGb*MNdgI&j?b*ef)$-9|`TfiB@@TPqcfR_( zI{o+Z{6>wS7% zUS6(SUi);Hp|iByZL@Z$C(r($;Y0WT(C@eJtRK4nhkif!`v31zhga~0{|>IACw#$I zu;7UB1z!%Xq9=U8SFqrS@C9EEuA(P=!B?>0i0}np4z8jne8E?+;E3=AUk0i0}np4z8jne8E?+;E3=AUk0i0}np4z8jne8E?+;E3=AUk0i0}np4z8jne8E?+;E3=AUk0i0}np4z8jne8E?+;E3=AUk0i0}np4z8jne8E?+;E3=AUk0i0}np4z8jne8E?+;E3=AUko~Kn$ zeV@7S-|D}vpV|*SPph8#K6Br{)qh<-wI6z(Rz3B7=DvTc|Ji!CKiU4(yq;IrYaV^| z?+=+@ns3!^|69+e>-D@@Uh8|;nO~Z3)o#CA&!_A4yjfoBd)Jv?ns3!^zgy3z>-D@@ zUh8|;Rll^pRk!+oa&ofXZ|!y|Ev3D^?IIay{_+FXZ~qk{nGs0batO=Uh8|;%7gk* z&1?O>>(=|~di3=_->W~$N0!(6eb?E2s(G#NUF-bSk7{1)_g%N%SJ$K4zQ2?g^;7fv zrmN@8>NSssuWbf0{pZi5S$#FH^}TELU;U`&TkCawRcGhdep!7rul2oax8K$NTA%61t3T`Q zm)39RHt*ee-LH4;_Pg3&>ocwM)B3UVx?k^l>Urq?Rju>W^<(FCzuxuK^U(dPTIZ+h z$Ik11z3Zvxq5D^VpHk&MRPe!c6d=b`&owa!o1kDb^3de>9WL-((0ou95BJFolo zuBV=d?qAh9KV3g|Uia%=XY+9Tqx)C2`+W4j`tOT#^P|d>*3a#}oOfJam6xh!UgOSd zZhlvJ()zjGm-CMMc~D+{|9-HpyqMRx^O~FA6_$o-=XPJtJMQOUh?&YW<~8oT=H_>W zrJ>rn-Iw!@`*|2*rt+-!|G&rmul_!q`cYwNs8;)-ulKt*e});>=amWFEa`+C3s0wf+q AY5)KL literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/background/backround_lvl3_6.npy b/src/jaxatari/games/sprites/up_n_down/background/backround_lvl3_6.npy new file mode 100644 index 0000000000000000000000000000000000000000..630f3cc975e6b7986da5da995621ee4521423776 GIT binary patch literal 98624 zcmeI4zm8m06U0r*Q*5>%nZO7L2nonQBnS~93ojx-7_ET>cm*fC2+zQa8BbtoZei}v z#rbpl%+x3OV|%MlcU7G_8ZiR?{QS$$zWDmB2fse}?e@vtvv2QjKfbyB{g+3#4{vV2 zdwT!#{o^0Md3yik?)3XlAOG;|?&a^FegF8UyO;Glk3M?;;mxBDU;f_*?>@Zw(Tr#eZK3r_NzaHYkzA^`xJNU zR({uh-t}Ai)jz_uzqO`)io10yziT`1`m6owm*QG$fA3oNX|4MdcWa%Odix!i>!g4M@18$zQ|9n<{-e&9Mhx+B#>M5Pq{@%5_pZW>!&A->Y zPwxMIal{|;i*T2-&U4SZ{jK}-yj#0`FRkD8OZ#0vTF=kkr+nSo<$G!Uu3y^k`q6rR z_CDq7)-K;m>v#Rqe%FuI^RxFU-`3jYzgqv%{^bUEN_P6TqJ&*YG-ly~2n)t4cyMAbYtNz~eh)?f*IwWTkE{lYFDrJt6svjfAyO5D(=>o=1;4BJ-@m?#Gmvk?$)%wI_~P# ze$_{~_OD)(Ud3DM>VEasNAs%t(aO)&PyK|q@*RDi_!F-D+Ar{dfu(8 z=^xEs=O3+q^!`@8dfu%^>!bNDU+t&)inpHEdDZJyy?WlQkM_sqtNkwj*7NSZqx*a7 z)p>3`+7FkH_P6ruJ+GC|=<~hx>O8let#>v*-8WkA(er!j)p>3`+7FkH_P6ruJ+GC| z=<~hx>O8let#>v*-8WkA(eqpN(s{+*`c3`$ukX>^ecG>c3D^GA|NeK79Z$$_#ogNB z^15DkpZ4oa!nHqj@}-$MuQea%*y(4*Q@7cWd3}*6u#-*L{R*f9kfpsoz?w zzlyuH&ePhRr~U3c?N8m7C-qP7M{BoV&ugvowAOyNcIRn->f~!LbDEF6-V>(%sgtk0 z%xFILdXuldSyJy$wfg_Mp1t4c`PucGe!uprPx-z2ef~P;zrWV%x1OI}ySY$)?)mh4 zIc)V?&(E&i`&NGL`Sd4Wdl^W+ynO6=`6gJa-+F#_{U#dEdz9ZfmY@4)uhnloKf87_ zq5RzQ>Gv|&>bIVsUAyg^Qrf@TAjSi&5P^zYI#z=uHN*|&7+E16S_8-mv()X+1Zhf>rE?@0Gn*XKmSHIo*Xn$P3 z+J7|vOW&`4yYBO)%mI4R9Dwa z{1vaRXLWw+H`Uek5`V?3>sg(j`b~9py~LmJ>iVklBYz2xzU}wZ)%}>wkMcwKY`*S3 z;!n7{Fa6v1M*b3B-B03Acy)c%`H{bbN8k3nR`+8zKgti`v-!IFh(F=(zVvV38~ICk zbw7ze;nnq3=STh$9(~*QTHTM?{3t(!&*tmyBmRWD`_jLCZ{#oG)%_&?gjd&Bogev2 zc=T=GYjrQHb2S_;j{U=`-nf`?!NSI z-y8W$cy&LCKjGE&Rp&?k5*~fq_gdYL+59L!gwN*d?j!z$yZh3=eQ)G1;nn>l{)AW8 zSDhdEOL+8c-)nV0X7i)`5I&o)yN~!2?(R$f_Pvq6gje^I_!C}TUv+-uFX7R*eXrI1 zn9Yy!L-=gI?mpsAxVta?+xJHP5?t-QR21pL}gU&wKlk ze8u11nL3}Be9g_*x%b!jv*zs@J?N=9es14y{Mh=@nzi`3^)r5K{YbvzM{5Q~p3i&j zo6nJVmD#zgWFQ$x29kkfAQ?yol7VC(8At|_fn*>VNCuLDWFQ$x29kkfAQ?yol7VC( z8At|_fn*>VNCuLDWFQ$x29kkfAQ?yol7VC(8At}|We66eBo$}SXI{!=Or}v?C zb^Xd$>+1Y3ouA%^*46bZU#+Y2zjS_jA6i$}uY9$x&i~T*ssB`0*H8Qjudc5;Kk}FG z=&SqpW)<#+@m1q~~6Ab$z3uATT3UR_^xe&jFV(U;H1y$n!(2;ZAs0VDo|SJ1FB z2lAKj>e`7v;nnq3=STh$9)0&Hhgjdk8G6(XP@ao!$KjGE&Rp&?k5*~f|eB8?b<%jUS*%dJ2 zPk03lD{~-!39qi5_!C}TUv+-uFX7Rb&&RzCP<{yCn_U4T{)AW1urdeom+@G>}I<*Z#w-sKbqMmW%jjwn>U^Q zobS!dOPPIa-{wuHKj%j?`=rdiwr}&M)1ULbnRzL*kL}yM>GbFPXl9?3+1K`M-gNqN zzBe;3W%jXsn>U^QoFC2XlQR3-zRjCXf6n)2=B3O&wr}&M)1ULBnSD}bU)#5N)9KIo z-pss|*~j*6-gNqNel)XB%Is_VHg7upIp3R^moodzP4}krqiGEy_tC_vybiDyy^7k{AgyMl-bwzZQgYHbG|n- zubSDPd8XSu_j@z*Q)Zvu`>8+YTRA-s=b2C4*!E4QKj*8N`6;ts@BP%D^R1kohx5#* zZfyIe)1ULz%>0zuulIiH&-qqP&%=4gn4dB`a=%))Jm#Ov^u3gsN1tll@|fS7nV&NIaKBo&Jm#Ov^u3gsN1tll z@|fS7nV&NIaKBo&Jm#Ov^u3gsN1tll@|fS7nV&NIaKBo&Jm#Ov^u3gsN1tll@|fS7 znV&NIaKBo&Jm#Ov^u3gsN1tll@|fS7nV&NIaKBo&Jm#Ov^u3gsN1tll@|fS7nV&NI zaKBo&JmwqIua%icziQp`nBSY3pECP!zgo9E<{Q(mm6=DsYTfdf-^$EOnfbPF^Q}7b z>6bG7Qf8j*+kC6eeERif=B3O&wr}&TI`iq5GW}9!p6%OwtImA-^=9U!%s#en^Q}7b z>6bG7Qf8j*+kC6eeERif=B3O&wr}&TI`iq5GW}9!p6%OwtImA-^=9U!%s#en^Q}7b z>6bG7Qf8j*+kC6eeERif=B3O&wr}&TI`iq5GW}9!p6%OwtImA-^=9U!%s#en^Q}7b z>6bG7Qf8j*+kC6eeERif=B3O&wr}&TI`iq5GW}9!p6%OwtImA-^=9U!%s#en^Q}7b z>6bG7Qf8j*+kC6eeERif=B3O&wr}&TI`iq5GW}9!p6%OwtImA-^=9U!%s#en^Q}7b z>6bG7Qf8j*+kC6eeERif=B3O&wr}&TI`iq5GW}9!p6%OwtImA-wKDxvW`6H|^EaJ+ zIB(4Tl9(<-+SNuO=lm@w=(@xW`6H|^EaJ+ zIB(4Tl9(<-+SNuO=lm@w=(@xW(+<; z$lOoa{ATMs5Bsz-{ZnQ>_h;+YhyBRhPucut>pTzpv@-ouW(+<;$lOoa{ATMs z5Bsz-{ZnQ>_h;+YhyBRhPucut>pTzpv@-ouW(+<;$lOoa{ATMs5Bsz-{ZnQ> z_h;+YhyBRhPucut>pTzpv@-ouW(+<;$lOoa{ATMs5Bsz-{ZnQ>_h;+YhyBRh zPucut>pah^KL7sZucu2s^h?Bsrey5(6v?iCKA^KekrH^oM&FOZh4l^ePjBO zxj$RCKI~^qzm!vd&NHuCw>-<|zA^pC+@GymANDh*U&^UJ=b2ZnTb|`}-Mw%`_%W&JYULt zKfm=)?}PK^f7^B5N6N?jymQ*Rv)g_@cD}8u%1`TJu-m zKj#j%W>!C6b-$M{U!E#I)%~meEAQsUySo2u|Kt9)|Gj|Yezxz&zQ4!)ZQtK}TbJm>ouhwUsyH{@Y{g{8Rb5!|Q|7yMJT%&xe@5lT{ zo%7Nh>tC&3+Iy9~>iaSORqo+UJ*@w1{Y}01^7G^V?w+>Z>bRfn``Nm?8|}Ei?fX0K zXZwDR@7HlZ_8q9ykKfPj?_c%(RK9KEegEhCrx~B>`=$SkGrpe5KGpns_W$EwSKlxF zXWudVRP*ff&d#sCU-}>Co%7GBS@+KOujX0je3e^$KKj4PJ^uXrzgN7k-dug}z5CLy zy8mtW=Y96(>ig^6mwwg#Z@WM5vo}}YU+=#3tL}f>{du3gx%&Qk_oZKT|6cxge|mb_ zF+T_L#N$ z`>oz@k8eJ*^Xl)ndcToxwa2X0-*5GPdwlbeomYRq)%%Tnt376|{(h_X+vA&$?7aH> zt=@0sTkSDx_4ix7-yYw5Wari2Z}om7-)fIptH0ms{r33gBRjADeyjHz`Br<(TK)Z2 z@3+S{AK7{J_glT+$hX>K*6Q!KdcQrs`N+;sojeRD_ zKzxet&EjO7vv;;I=+AZD#D@Z|KDAJ0$E zPENmheEGxW{YUp6Up}~a{^viw|J{>|XMg_5xA(ulc*gIZ-#$M(IluG4*~#tq&Q5+k z$$xwE_rJb4yvm=aPwyOF{L%YG?N$7-{@vJ(?5Ccxe%08%=2uTyzi#Xv_ES$;ziMn> z^Q)(Cg_Up2O``PEa_uN%9E{nS&|uNvFe{OT#|*Nxr7e(EXfSB>p!e)W{~ z>&EV3KlPOLtH$;RVa=ZtOnxQ%_mHYHVNgtEa4AH+B#Esi&-8HMXz$)l=558@q@7)Kk{48r#?W z>M85jjorh3>M842jqPiG^_2DN#_nN1^_2Ci#`ZP8ddm8BWB0J1ddm7$WBZz4J!Sp6 zv3uB0J!So>v36vVPUrzUEg?S-)=V9`;jDS-)y*U-PS{tY0^F5BsU7tY0;@uldze)~_48hyB!3 z)~_1d*Zk@!>(`Cl!+z>1>sO8KYku{V_3Os&VL$bh^{dAAHNSew`gLRXu%CL$`c-55 znqNI-{kpMx*iSuW{i?Bj&99!ae%;tT?5Ccxe%08%=2uTyzi#Xv_ES$;ziMn>^Q)(< zUpICS`>Cg_Up2O``PEa_uN%9E{nS&|uNvETmEZ4Iyz=X&yNCVMQ`WB<+t>W+EB1YN zb`Sfir>tK!wy*isSM2-l>>l=2Pg%ceY+v)Muh{q9**)y5p0a+`*uLghU$O7IvwPTA zJ!So>v3>l=0Pg%ceY+v)Mr>tK$b`Sfhr>tK! zwy*isQ`YZg`Sai2KUDZ}4>PN$tY0;@uldze)~_48hyB!3)~_1d*Zk@!>(`Cl!+z>1 z>sO8KYku{V_3Os&VL$bh^{d9KzV=ytukQQqyy~x>a`pY}YajI$`@TE7hkeyk)~_1d z*Zk@$_I-DD5BsX8tY0;@uldzi?ECKQ9`;pFS-)y*U-PRctXD^N4?C)-tY0;@uldyz z)~h4ChaJ^Z)~_1d*Zk@!>(`Cl!+z>1>sO8KYku{V_3Os&VL$bh^{dAAHNSew`gLRX zu%CL$`c-55nqNI-{kpMx*iSuW{i?Bj&99!ae%;tT?5Ccxe%08%=2uTyzi#Xv_ES$; zzv^s%`>3zj_ubh&?5m!#e%08%=2uTyzi#Xv_ES$;ziMn>^Q)(Cg_Up2O` z`PEa_uN%9E{nS&|uNvFe{OT#|*Nxr7e(EXfSB>p!e)W{~>&EV3KlPOLtH$;(`Cl!+z>1 z>sO8KYku{V_3Os&VL$bh^{dAAHNSew`gLRXu%CL$`c-55nqNI-{kpMx*iSuW{i?Bj z&99!ae%;tT?5Ccxe%08%=2uTyzi#Xv_ES$;ziMn>^Q)(Cg_Up2O``PEa_ zuN%9E{nS&|uNvFe{OT#|*Nxr7e(EXfH{9dD^O|4B`E{Hdm>ZZIm>ZZIm>ZZIm>ZZI zm>ZZIm>ZZIm>ZZIm>ZZIm>ZZIm>ZZIm>ZZIm>ZZIm>ZZIm>ZZIm>ZZIm>ZZIm>ZZI z`1c0h{^-_WerAKYfw_UXfw_UXfw_UXfw_UXfw_UXfw_UXfw_UXfw_UXfw_UXfw_UX zfw_UXfw_UXf!F&69zJ~Qu;u4afwsjBRciU&x$NpP+_1CMWZ{5T9)%tJEYyYjh zHP7DOZ{5T9d+T@IdF;QHuRHJm?LT{L`}f4z@A!VR&+&a$f4|Ru$M>6kj_)(~HT$e? z;5zmB`}4o|_p0~M?>c`i_uy^clebjjDcjfg>fPAB zDZ7X7)vK|6Qns(})w{8MQ+5yEtM8r7pR)Z{@BO|yullZZ`={Uges8?$v(oL8e((Eg zyz0Bs?VEn@`@Qk1&q}vX`n~U~@v84iw{QBr@At;5J}ccm>G!^`#;d+7-M;DfzQ1W+ zbKJEyZT<;_j_aWrfi?B z@Aca|r+xk2e(Eds{oZ($f2G@J^}V|9_s-_u%JxnDe7`r|+9!SQ_j}u8mB;)k>z}^& zd;M;j&6l$I)AxR_-`?5$Dcdi7@AvxMG@CDF^QZ6qUcbGw`BS!E`rhyLyJvz*^z7?DIy1w7z{|8{z^<{K>t=^>H{{uEo42=K) literal 0 HcmV?d00001 diff --git a/src/jaxatari/games/sprites/up_n_down/background/backround_lvl3_9.npy b/src/jaxatari/games/sprites/up_n_down/background/backround_lvl3_9.npy new file mode 100644 index 0000000000000000000000000000000000000000..fdef4bd73c73eb9703dd7cec4f27fbec7badad5e GIT binary patch literal 11072 zcmeH|%}T>S6ou>7r8Y!^cL78#Cd zGMP&Xfqp%A&fIfo>G%2d*~QJykV`>LGtBR=k%Tr!JcwrtWAzbsz2@uieYfgSyXWxt!K*o<4UCaaWse)@MWPJfu6v z{+fT!V?WLP_`b_(e?C9!y4n}#$9|gsFRl0E%g67F%0BPG4R32ic zZSJJ@!TD|O5~COUP&vlT%Fb#ZoLBC$ILkg%w%EyLCbbXW{HC|X_rCn9ZmlEcXFlpG z)6W{S59XupW1f$lnU{H}t4u#@%s!Zpx{rB2c4l7Yp{_FhtTFpwKI%T^`Pi9xnTNW{ z^s~n7gZZedOuyKedAUDzmFZ`V*$4AcSDAjXG4pbN>MGOE8nX}PqpmXjVq@mz{?t{b PpEYJ5%tu{i`gP5}h++X0 literal 0 HcmV?d00001 From 1fea1346f32003b73f969d98fb107aedd73646de Mon Sep 17 00:00:00 2001 From: Sebastian Jilge Date: Wed, 25 Mar 2026 17:26:36 +0100 Subject: [PATCH 75/76] small fix for training with mods --- src/jaxatari/games/jax_upndown.py | 29 ++++++++++++++----- .../games/mods/upndown/upndown_mod_plugins.py | 2 +- .../games/mods/upndown_mod_plugins.py | 2 +- 3 files changed, 23 insertions(+), 10 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index 63041d7e7..a46d358eb 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -104,7 +104,8 @@ class UpNDownConstants(struct.PyTreeNode): LIFE_BOTTOM_X_POSITIONS: chex.Array = struct.field(default_factory=lambda: jnp.array([13, 18, 25, 33, 33])) # X positions for 5 life cars LIFE_BOTTOM_Y: int = 195 # Collectible constants - unified dynamic spawning - MAX_COLLECTIBLES: int = 1 # Maximum collectibles that can exist at once (pool of mixed types) + MAX_COLLECTIBLES: int = 4 # Fixed collectible pool size used for observation/state schema stability + MAX_ACTIVE_COLLECTIBLES: int = 1 # Runtime cap of simultaneously active collectibles COLLECTIBLE_SPAWN_INTERVAL: int = 200 # Steps between spawn attempts COLLECTIBLE_DESPAWN_DISTANCE: int = 500 # Distance beyond which collectibles despawn # Collectible types (indices for type field) @@ -239,6 +240,7 @@ class UpNDownObservation: jump_cooldown: chex.Array is_on_steep_road: chex.Array round_started: chex.Array + level: chex.Array def _replace(self, **kwargs): return self.replace(**kwargs) @@ -277,17 +279,17 @@ def __init__(self, consts: UpNDownConstants = None, reward_funcs: list[callable] # Player car: 10 values (x, y, w, h, speed, type, road, road_index_A, road_index_B, direction_x) # Enemy cars: MAX_ENEMY_CARS * 12 = 8 * 12 = 96 (x, y, w, h, speed, type, road, road_index_A, road_index_B, direction_x, active, age) # Flags: NUM_FLAGS * 5 = 8 * 5 = 40 (y, road, segment, color, collected per flag) - # Collectibles: MAX_COLLECTIBLES * 6 = 1 * 6 = 6 (y, x, road, color_idx, type, active per collectible) + # Collectibles: MAX_COLLECTIBLES * 6 = 4 * 6 = 24 (y, x, road, color_idx, type, active per collectible) # Flags collected mask: NUM_FLAGS = 8 - # Score, lives, is_jumping, jump_cooldown, is_on_steep_road, round_started: 6 - # Total: 10 + 96 + 40 + 6 + 8 + 6 = 166 + # Score, lives, is_jumping, jump_cooldown, is_on_steep_road, round_started, level: 7 + # Total: 10 + 96 + 40 + 6 + 8 + 7 = 167 self.obs_size = ( 10 + # player car self.consts.MAX_ENEMY_CARS * 12 + # enemy cars (all fields) self.consts.NUM_FLAGS * 5 + # flags self.consts.MAX_COLLECTIBLES * 6 + # collectibles (all fields) self.consts.NUM_FLAGS + # flags_collected_mask - 6 # score, lives, is_jumping, jump_cooldown, is_on_steep_road, round_started + 7 # score, lives, is_jumping, jump_cooldown, is_on_steep_road, round_started, level ) # Speed dividers for movement timing (indexed by speed level) self._speed_dividers = jnp.array([0, 1, 2, 4, 8, 16, 16, 16, 16]) @@ -993,8 +995,10 @@ def _collectible_step(self, state: UpNDownState, new_player_y: chex.Array, playe state.collectible_spawn_timer - 1, ) - # Attempt to spawn when timer hits 0 + # Attempt to spawn when timer hits 0 and active collectible cap allows it should_spawn = state.collectible_spawn_timer <= 0 + active_collectible_count = jnp.sum(state.collectibles.active.astype(jnp.int32)) + can_spawn_under_cap = active_collectible_count < self.consts.MAX_ACTIVE_COLLECTIBLES inactive_mask = ~state.collectibles.active first_inactive = jnp.argmax(inactive_mask.astype(jnp.int32)) @@ -1025,7 +1029,12 @@ def _collectible_step(self, state: UpNDownState, new_player_y: chex.Array, playe ) # Create mask for which collectibles to update - update_mask = (jnp.arange(self.consts.MAX_COLLECTIBLES) == spawn_idx) & should_spawn & has_inactive_slot + update_mask = ( + (jnp.arange(self.consts.MAX_COLLECTIBLES) == spawn_idx) + & should_spawn + & has_inactive_slot + & can_spawn_under_cap + ) # Update collectibles with proper masking - spawn new items spawned_y = jnp.where(update_mask, y_spawn, state.collectibles.y) @@ -2071,6 +2080,7 @@ def _get_observation(self, state: UpNDownState) -> UpNDownObservation: jump_cooldown=jnp.int32(state.jump_cooldown), is_on_steep_road=jnp.int32(is_on_steep_road), round_started=jnp.int32(state.round_started), + level=jnp.int32(state.level), ) @partial(jax.jit, static_argnums=(0,)) @@ -2140,7 +2150,7 @@ def obs_to_flat_array(self, obs: UpNDownObservation) -> jnp.ndarray: - Flags: NUM_FLAGS * 5 values (y, road, segment, color, collected per flag) - Collectibles: MAX_COLLECTIBLES * 6 values (y, x, road, color_idx, type, active per collectible) - Flags collected mask: NUM_FLAGS values - - Score, lives, is_jumping, jump_cooldown, is_on_steep_road, round_started: 6 values + - Score, lives, is_jumping, jump_cooldown, is_on_steep_road, round_started, level: 7 values """ return jnp.concatenate([ self.flatten_car(obs.player_car), @@ -2154,6 +2164,7 @@ def obs_to_flat_array(self, obs: UpNDownObservation) -> jnp.ndarray: jnp.array([obs.jump_cooldown], dtype=jnp.int32), jnp.array([obs.is_on_steep_road], dtype=jnp.int32), jnp.array([obs.round_started], dtype=jnp.int32), + jnp.array([obs.level], dtype=jnp.int32), ]) def action_space(self) -> spaces.Discrete: @@ -2174,6 +2185,7 @@ def observation_space(self) -> spaces.Dict: - jump_cooldown: int (0-48) - is_on_steep_road: int (0 or 1) - round_started: int (0 or 1) + - level: int (0-2) """ return spaces.Dict({ "player_car": spaces.Dict({ @@ -2228,6 +2240,7 @@ def observation_space(self) -> spaces.Dict: "jump_cooldown": spaces.Box(low=0, high=48, shape=(), dtype=jnp.int32), "is_on_steep_road": spaces.Box(low=0, high=1, shape=(), dtype=jnp.int32), "round_started": spaces.Box(low=0, high=1, shape=(), dtype=jnp.int32), + "level": spaces.Box(low=0, high=self.consts.LEVEL_COUNT - 1, shape=(), dtype=jnp.int32), }) def image_space(self) -> spaces.Box: diff --git a/src/jaxatari/games/mods/upndown/upndown_mod_plugins.py b/src/jaxatari/games/mods/upndown/upndown_mod_plugins.py index a23f9553f..2ac445744 100644 --- a/src/jaxatari/games/mods/upndown/upndown_mod_plugins.py +++ b/src/jaxatari/games/mods/upndown/upndown_mod_plugins.py @@ -27,7 +27,7 @@ class HigherPlayerSpeedMod(JaxAtariInternalModPlugin): class MoreCollectiblesMod(JaxAtariInternalModPlugin): constants_overrides = { - "MAX_COLLECTIBLES": 4, + "MAX_ACTIVE_COLLECTIBLES": 4, "COLLECTIBLE_SPAWN_INTERVAL": 120, } diff --git a/src/jaxatari/games/mods/upndown_mod_plugins.py b/src/jaxatari/games/mods/upndown_mod_plugins.py index 5382e0efa..909d8e187 100644 --- a/src/jaxatari/games/mods/upndown_mod_plugins.py +++ b/src/jaxatari/games/mods/upndown_mod_plugins.py @@ -21,7 +21,7 @@ class HigherPlayerSpeedMod(JaxAtariInternalModPlugin): class MoreCollectiblesMod(JaxAtariInternalModPlugin): constants_overrides = { - "MAX_COLLECTIBLES": 4, + "MAX_ACTIVE_COLLECTIBLES": 4, "COLLECTIBLE_SPAWN_INTERVAL": 120, } From caa0fc7c8abe7d30c63ee4e5fdd30cf42120dd51 Mon Sep 17 00:00:00 2001 From: Sebastian Date: Sat, 28 Mar 2026 15:55:32 +0100 Subject: [PATCH 76/76] fixes to improve training --- src/jaxatari/games/jax_upndown.py | 162 ++++++++++++++---- .../games/mods/upndown/upndown_mod_plugins.py | 9 +- 2 files changed, 135 insertions(+), 36 deletions(-) diff --git a/src/jaxatari/games/jax_upndown.py b/src/jaxatari/games/jax_upndown.py index a46d358eb..ad39bd572 100644 --- a/src/jaxatari/games/jax_upndown.py +++ b/src/jaxatari/games/jax_upndown.py @@ -54,7 +54,7 @@ class UpNDownConstants(struct.PyTreeNode): GROUND_COLLISION_DISTANCE: float = 3.0 # Tight collision distance for ground collisions LATE_JUMP_ENEMY_SCORE: int = 400 STEEP_ROAD_SPEED_REDUCTION_INTERVAL: int = 8 # Frames between each speed reduction on steep roads - PASSIVE_SCORE_INTERVAL: int = 60 # Steps between passive score awards + PASSIVE_SCORE_INTERVAL: int = 45 # Steps between passive score awards PASSIVE_SCORE_AMOUNT: int = 10 # Points awarded for passive scoring COLLISION_THRESHOLD: float = 5.0 # Distance threshold for flag/collectible collision ACCELERATION_INTERVAL: int = 6 # Frames between speed changes when holding up/down @@ -105,8 +105,8 @@ class UpNDownConstants(struct.PyTreeNode): LIFE_BOTTOM_Y: int = 195 # Collectible constants - unified dynamic spawning MAX_COLLECTIBLES: int = 4 # Fixed collectible pool size used for observation/state schema stability - MAX_ACTIVE_COLLECTIBLES: int = 1 # Runtime cap of simultaneously active collectibles - COLLECTIBLE_SPAWN_INTERVAL: int = 200 # Steps between spawn attempts + MAX_ACTIVE_COLLECTIBLES: int = 2 # Runtime cap of simultaneously active collectibles + COLLECTIBLE_SPAWN_INTERVAL: int = 160 # Steps between spawn attempts COLLECTIBLE_DESPAWN_DISTANCE: int = 500 # Distance beyond which collectibles despawn # Collectible types (indices for type field) COLLECTIBLE_TYPE_CHERRY: int = 0 @@ -221,6 +221,7 @@ class UpNDownState: jump_key_released: chex.Array # True if jump button was NOT pressed in previous step last_extra_life_score: chex.Array # Score at which last extra life was awarded jump_total_duration: chex.Array # Total duration of the current/last jump for rendering arc + level_cycle_counter: chex.Array # Increments on each level transition to diversify RNG def _replace(self, **kwargs): return self.replace(**kwargs) @@ -239,6 +240,10 @@ class UpNDownObservation: is_jumping: chex.Array jump_cooldown: chex.Array is_on_steep_road: chex.Array + road_section_start_x: chex.Array + road_section_start_y: chex.Array + road_section_end_x: chex.Array + road_section_end_y: chex.Array round_started: chex.Array level: chex.Array @@ -254,6 +259,10 @@ class UpNDownInfo: movement_steps: jnp.ndarray # Steps since round started jump_slope: jnp.ndarray # Current jump trajectory slope player_road_segment: jnp.ndarray # Current road segment index + road_section_start_x: jnp.ndarray + road_section_start_y: jnp.ndarray + road_section_end_x: jnp.ndarray + road_section_end_y: jnp.ndarray def _replace(self, **kwargs): return self.replace(**kwargs) @@ -281,15 +290,18 @@ def __init__(self, consts: UpNDownConstants = None, reward_funcs: list[callable] # Flags: NUM_FLAGS * 5 = 8 * 5 = 40 (y, road, segment, color, collected per flag) # Collectibles: MAX_COLLECTIBLES * 6 = 4 * 6 = 24 (y, x, road, color_idx, type, active per collectible) # Flags collected mask: NUM_FLAGS = 8 - # Score, lives, is_jumping, jump_cooldown, is_on_steep_road, round_started, level: 7 - # Total: 10 + 96 + 40 + 6 + 8 + 7 = 167 + # Score/lives/jump state and geometry context: 11 scalar values + # (score, lives, is_jumping, jump_cooldown, is_on_steep_road, + # road_section_start_x, road_section_start_y, road_section_end_x, road_section_end_y, + # round_started, level) + # Total: 10 + 96 + 40 + 24 + 8 + 11 = 189 self.obs_size = ( 10 + # player car self.consts.MAX_ENEMY_CARS * 12 + # enemy cars (all fields) self.consts.NUM_FLAGS * 5 + # flags self.consts.MAX_COLLECTIBLES * 6 + # collectibles (all fields) self.consts.NUM_FLAGS + # flags_collected_mask - 7 # score, lives, is_jumping, jump_cooldown, is_on_steep_road, round_started, level + 11 # score/lives/jump state, road section start/end, round_started, level ) # Speed dividers for movement timing (indexed by speed level) self._speed_dividers = jnp.array([0, 1, 2, 4, 8, 16, 16, 16, 16]) @@ -361,12 +373,14 @@ def _get_spawn_position_for_level(self, level: chex.Array, road: chex.Array) -> @partial(jax.jit, static_argnums=(0,)) def _on_level_completed(self, state: UpNDownState) -> UpNDownState: """Advance to next level and freeze until release+press input starts it.""" - rng_key, enemy_key = jax.random.split(state.rng_key) + rng_key, enemy_key, flag_key = jax.random.split(state.rng_key, 3) next_level = (state.level + jnp.int32(1)) % jnp.int32(self.consts.LEVEL_COUNT) + next_cycle_counter = state.level_cycle_counter + jnp.int32(1) start_road = jnp.int32(0) start_segment, player_start_y, start_x = self._get_spawn_position_for_level(next_level, start_road) enemy_cars = self._initialize_enemies(enemy_key, player_start_y, next_level) + flags = self._initialize_flags(flag_key, next_level) collectibles = self._initialize_collectibles() player_car = state.player_car._replace( @@ -395,15 +409,18 @@ def _on_level_completed(self, state: UpNDownState) -> UpNDownState: movement_steps=jnp.array(0), steep_road_timer=jnp.array(0, dtype=jnp.int32), jump_slope=jnp.array(0.0, dtype=jnp.float32), + flags=flags, + flags_collected_mask=jnp.zeros(self.consts.NUM_FLAGS, dtype=jnp.bool_), collectibles=collectibles, collectible_spawn_timer=jnp.array(0, dtype=jnp.int32), enemy_cars=enemy_cars, enemy_spawn_timer=jnp.array(self.consts.ENEMY_SPAWN_INTERVAL_BASE, dtype=jnp.int32), awaiting_respawn=jnp.array(False), awaiting_round_start=jnp.array(True), - input_released=jnp.array(False), + input_released=jnp.array(True), jump_key_released=jnp.array(True), jump_total_duration=jnp.array(self.consts.JUMP_FRAMES, dtype=jnp.int32), + level_cycle_counter=next_cycle_counter, rng_key=rng_key, ) @@ -987,6 +1004,12 @@ def _collectible_step(self, state: UpNDownState, new_player_y: chex.Array, playe Tuple of (updated_collectibles, score_delta, new_spawn_timer, new_rng_key) """ rng_key, key1, key2, key3, key4 = jax.random.split(rng_key, 5) + # Salt collectible randomness with level transition count so revisiting a level + # does not reproduce the same collectible placement pattern. + key1 = jax.random.fold_in(key1, state.level_cycle_counter) + key2 = jax.random.fold_in(key2, state.level_cycle_counter) + key3 = jax.random.fold_in(key3, state.level_cycle_counter) + key4 = jax.random.fold_in(key4, state.level_cycle_counter) # Collectible spawning logic - decrement timer and spawn when ready (use jnp.where for branchless) new_collectible_timer = jnp.where( @@ -1043,6 +1066,14 @@ def _collectible_step(self, state: UpNDownState, new_player_y: chex.Array, playe spawned_color_idx = jnp.where(update_mask, color_spawn, state.collectibles.color_idx) spawned_type_id = jnp.where(update_mask, type_id_spawn, state.collectibles.type_id) spawned_active = jnp.where(update_mask, True, state.collectibles.active) + + # Keep collectible X aligned to the active level geometry. + # This prevents stale positions when switching levels with existing collectible state. + spawned_segments = jax.vmap(lambda y: self._get_road_segment(y, state.level))(spawned_y) + projected_x_road_0 = self._get_x_on_road(spawned_y, spawned_segments, corners_a, corners_y) + projected_x_road_1 = self._get_x_on_road(spawned_y, spawned_segments, corners_b, corners_y) + aligned_spawned_x = jnp.where(spawned_road == 0, projected_x_road_0, projected_x_road_1) + spawned_x = jnp.where(spawned_active, aligned_spawned_x, spawned_x) # Despawn logic - remove collectibles too far from player def check_despawn(idx): @@ -1131,6 +1162,7 @@ def _death_step(self, state: UpNDownState) -> UpNDownState: lives=lives, is_dead=is_dead, awaiting_respawn=awaiting_respawn, + input_released=jnp.where(died, jnp.array(False), state.input_released), player_car=player_car, ) @@ -1401,6 +1433,7 @@ def trigger_death(s): lives=s.lives - 1, is_dead=jnp.array(True), awaiting_respawn=jnp.array(True), + input_released=jnp.array(False), player_car=dead_car, ) @@ -1496,6 +1529,27 @@ def _initialize_collectibles(self) -> Collectible: active=jnp.zeros(self.consts.MAX_COLLECTIBLES, dtype=jnp.bool_), ) + @partial(jax.jit, static_argnums=(0,)) + def _initialize_flags(self, key: chex.Array, level: chex.Array) -> Flag: + """Initialize flags so they align with the requested level geometry.""" + # Evenly spread flags along the track with small jitter. + base_y = jnp.linspace(-900.0, -100.0, self.consts.NUM_FLAGS) + jitter = jax.random.uniform(key, shape=(self.consts.NUM_FLAGS,), minval=-40.0, maxval=40.0) + flag_y_offsets = base_y + jitter + + # Alternate roads 0/1 for variety. + flag_roads = jnp.arange(self.consts.NUM_FLAGS) % 2 + flag_segments = jax.vmap(lambda y: self._get_road_segment(y, level))(flag_y_offsets) + flag_color_indices = jnp.arange(self.consts.NUM_FLAGS) + + return Flag( + y=flag_y_offsets, + road=flag_roads, + road_segment=flag_segments, + color_idx=flag_color_indices, + collected=jnp.zeros(self.consts.NUM_FLAGS, dtype=jnp.bool_), + ) + @partial(jax.jit, static_argnums=(0,)) def _initialize_enemies(self, key: chex.Array, player_start_y: chex.Array, level: chex.Array) -> EnemyCars: """Seed the initial set of visible enemies around the player.""" @@ -1727,7 +1781,10 @@ def _enemy_step_main(self, state: UpNDownState) -> UpNDownState: @partial(jax.jit, static_argnums=(0,)) def _respawn_after_collision(self, state: UpNDownState, new_lives: chex.Array) -> UpNDownState: - """Respawn the player on a random road while preserving score and flags.""" + """Respawn the player on a random road while preserving score and flags. + + The caller is expected to gate this on a release-then-press input edge. + """ rng_key, road_key, enemy_key = jax.random.split(state.rng_key, 3) respawn_road = jax.random.randint(road_key, shape=(), minval=0, maxval=2) @@ -1776,10 +1833,11 @@ def _respawn_after_collision(self, state: UpNDownState, new_lives: chex.Array) - enemy_spawn_timer=jnp.array(self.consts.ENEMY_SPAWN_INTERVAL_BASE, dtype=jnp.int32), awaiting_respawn=jnp.array(False), awaiting_round_start=jnp.array(True), # Wait for input to start round after respawn - input_released=jnp.array(False), # Require button release before round can start + input_released=jnp.array(True), # Allow same press to clear awaiting_round_start jump_key_released=jnp.array(True), last_extra_life_score=state.last_extra_life_score, jump_total_duration=jnp.array(self.consts.JUMP_FRAMES, dtype=jnp.int32), + level_cycle_counter=state.level_cycle_counter, rng_key=rng_key, ) @@ -1864,6 +1922,7 @@ def handle_ground_collision(): lives=state.lives - 1, is_dead=jnp.array(True), awaiting_respawn=jnp.array(True), + input_released=jnp.array(False), player_car=dead_car, ) @@ -1899,28 +1958,7 @@ def _reset_jit(self, key: chex.PRNGKey) -> Tuple[UpNDownObservation, UpNDownStat initial_level = jnp.int32(0) start_road = jnp.int32(jax.random.randint(rng_key, shape=(), minval=0, maxval=2)) start_segment, player_start_y, player_start_x = self._get_spawn_position_for_level(initial_level, start_road) - - # Evenly spread flags along the track with small jitter - base_y = jnp.linspace(-900.0, -100.0, self.consts.NUM_FLAGS) - jitter = jax.random.uniform(flag_key, shape=(self.consts.NUM_FLAGS,), minval=-40.0, maxval=40.0) - flag_y_offsets = base_y + jitter - - # Alternate roads 0/1 for variety - flag_roads = jnp.arange(self.consts.NUM_FLAGS) % 2 - - # Calculate which road segment each flag is on based on Y position - flag_segments = jax.vmap(lambda y: self._get_road_segment(y, initial_level))(flag_y_offsets) - - # Each flag color index corresponds to its position (0-7) - flag_color_indices = jnp.arange(self.consts.NUM_FLAGS) - - flags = Flag( - y=flag_y_offsets, - road=flag_roads, - road_segment=flag_segments, - color_idx=flag_color_indices, - collected=jnp.zeros(self.consts.NUM_FLAGS, dtype=jnp.bool_), - ) + flags = self._initialize_flags(flag_key, initial_level) # Initialize collectibles as all inactive (will spawn dynamically with mixed types) collectibles = self._initialize_collectibles() @@ -1971,6 +2009,7 @@ def _reset_jit(self, key: chex.PRNGKey) -> Tuple[UpNDownObservation, UpNDownStat jump_key_released=jnp.array(True), last_extra_life_score=jnp.array(0, dtype=jnp.int32), jump_total_duration=jnp.array(self.consts.JUMP_FRAMES, dtype=jnp.int32), + level_cycle_counter=jnp.array(0, dtype=jnp.int32), ) initial_obs = self._get_observation(state) return initial_obs, state @@ -1992,7 +2031,10 @@ def step(self, state: UpNDownState, action: chex.Array) -> Tuple[UpNDownObservat state = state._replace(input_released=input_released) # Check if we're awaiting respawn - if so, check for input to trigger respawn - should_respawn = jnp.logical_and(state.awaiting_respawn, any_action) + should_respawn = jnp.logical_and( + jnp.logical_and(state.awaiting_respawn, any_action), + state.input_released, + ) # Respawn if player pressed any key while awaiting state = jax.lax.cond( @@ -2067,6 +2109,26 @@ def _get_observation(self, state: UpNDownState) -> UpNDownObservation: state.player_car.road_index_B, state.level, ) + + road_segment = jnp.where( + state.player_car.current_road == 0, + state.player_car.road_index_A, + state.player_car.road_index_B, + ) + corners_a, corners_b = self._get_track_corners_for_level(state.level) + corners_y = self._get_track_corners_y_for_level(state.level) + section_start_x = jnp.where( + state.player_car.current_road == 0, + corners_a[road_segment], + corners_b[road_segment], + ) + section_end_x = jnp.where( + state.player_car.current_road == 0, + corners_a[road_segment + 1], + corners_b[road_segment + 1], + ) + section_start_y = corners_y[road_segment] + section_end_y = corners_y[road_segment + 1] return UpNDownObservation( player_car=state.player_car, @@ -2079,6 +2141,10 @@ def _get_observation(self, state: UpNDownState) -> UpNDownObservation: is_jumping=jnp.int32(state.is_jumping), jump_cooldown=jnp.int32(state.jump_cooldown), is_on_steep_road=jnp.int32(is_on_steep_road), + road_section_start_x=jnp.int32(section_start_x), + road_section_start_y=jnp.int32(section_start_y), + road_section_end_x=jnp.int32(section_end_x), + road_section_end_y=jnp.int32(section_end_y), round_started=jnp.int32(state.round_started), level=jnp.int32(state.level), ) @@ -2150,7 +2216,7 @@ def obs_to_flat_array(self, obs: UpNDownObservation) -> jnp.ndarray: - Flags: NUM_FLAGS * 5 values (y, road, segment, color, collected per flag) - Collectibles: MAX_COLLECTIBLES * 6 values (y, x, road, color_idx, type, active per collectible) - Flags collected mask: NUM_FLAGS values - - Score, lives, is_jumping, jump_cooldown, is_on_steep_road, round_started, level: 7 values + - Score/lives/jump state and geometry context: 11 values """ return jnp.concatenate([ self.flatten_car(obs.player_car), @@ -2163,6 +2229,10 @@ def obs_to_flat_array(self, obs: UpNDownObservation) -> jnp.ndarray: jnp.array([obs.is_jumping], dtype=jnp.int32), jnp.array([obs.jump_cooldown], dtype=jnp.int32), jnp.array([obs.is_on_steep_road], dtype=jnp.int32), + jnp.array([obs.road_section_start_x], dtype=jnp.int32), + jnp.array([obs.road_section_start_y], dtype=jnp.int32), + jnp.array([obs.road_section_end_x], dtype=jnp.int32), + jnp.array([obs.road_section_end_y], dtype=jnp.int32), jnp.array([obs.round_started], dtype=jnp.int32), jnp.array([obs.level], dtype=jnp.int32), ]) @@ -2184,6 +2254,8 @@ def observation_space(self) -> spaces.Dict: - is_jumping: int (0 or 1) - jump_cooldown: int (0-48) - is_on_steep_road: int (0 or 1) + - road_section_start_x/y: current road section start point + - road_section_end_x/y: current road section end point - round_started: int (0 or 1) - level: int (0-2) """ @@ -2239,6 +2311,10 @@ def observation_space(self) -> spaces.Dict: "is_jumping": spaces.Box(low=0, high=1, shape=(), dtype=jnp.int32), "jump_cooldown": spaces.Box(low=0, high=48, shape=(), dtype=jnp.int32), "is_on_steep_road": spaces.Box(low=0, high=1, shape=(), dtype=jnp.int32), + "road_section_start_x": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), + "road_section_start_y": spaces.Box(low=-2000, high=0, shape=(), dtype=jnp.int32), + "road_section_end_x": spaces.Box(low=0, high=160, shape=(), dtype=jnp.int32), + "road_section_end_y": spaces.Box(low=-2000, high=0, shape=(), dtype=jnp.int32), "round_started": spaces.Box(low=0, high=1, shape=(), dtype=jnp.int32), "level": spaces.Box(low=0, high=self.consts.LEVEL_COUNT - 1, shape=(), dtype=jnp.int32), }) @@ -2260,6 +2336,20 @@ def _get_info(self, state: UpNDownState) -> UpNDownInfo: state.player_car.road_index_A, state.player_car.road_index_B, ) + corners_a, corners_b = self._get_track_corners_for_level(state.level) + corners_y = self._get_track_corners_y_for_level(state.level) + section_start_x = jnp.where( + state.player_car.current_road == 0, + corners_a[road_index], + corners_b[road_index], + ) + section_end_x = jnp.where( + state.player_car.current_road == 0, + corners_a[road_index + 1], + corners_b[road_index + 1], + ) + section_start_y = corners_y[road_index] + section_end_y = corners_y[road_index + 1] return UpNDownInfo( step_counter=jnp.int32(state.step_counter), @@ -2267,6 +2357,10 @@ def _get_info(self, state: UpNDownState) -> UpNDownInfo: movement_steps=jnp.int32(state.movement_steps), jump_slope=jnp.float32(state.jump_slope), player_road_segment=jnp.int32(road_index), + road_section_start_x=jnp.int32(section_start_x), + road_section_start_y=jnp.int32(section_start_y), + road_section_end_x=jnp.int32(section_end_x), + road_section_end_y=jnp.int32(section_end_y), ) @partial(jax.jit, static_argnums=(0,)) diff --git a/src/jaxatari/games/mods/upndown/upndown_mod_plugins.py b/src/jaxatari/games/mods/upndown/upndown_mod_plugins.py index 2ac445744..714ca6565 100644 --- a/src/jaxatari/games/mods/upndown/upndown_mod_plugins.py +++ b/src/jaxatari/games/mods/upndown/upndown_mod_plugins.py @@ -60,7 +60,10 @@ def _adjust_enemy_spawn_timer(self, state: UpNDownState, spawn_timer: chex.Array min_interval = jnp.int32(8) horizon = jnp.float32(1800.0) - progress = jnp.clip(state.movement_steps.astype(jnp.float32) / horizon, 0.0, 1.0) + in_reset_phase = jnp.logical_or(state.awaiting_round_start, state.awaiting_respawn) + effective_steps = jnp.where(in_reset_phase, jnp.int32(0), state.movement_steps) + + progress = jnp.clip(effective_steps.astype(jnp.float32) / horizon, 0.0, 1.0) decayed_interval = jnp.round( start_interval.astype(jnp.float32) - progress * (start_interval.astype(jnp.float32) - min_interval.astype(jnp.float32)) ).astype(jnp.int32) @@ -73,6 +76,8 @@ class TimeDecayCollectibleValueMod(JaxAtariInternalModPlugin): @partial(jax.jit, static_argnums=(0,)) def _collectible_score_values(self, state: UpNDownState, collectible_type_ids: chex.Array) -> chex.Array: base_scores = self._env.consts.COLLECTIBLE_SCORES[collectible_type_ids] - elapsed_decay = jnp.floor(state.movement_steps.astype(jnp.float32) / 200.0).astype(jnp.int32) + in_reset_phase = jnp.logical_or(state.awaiting_round_start, state.awaiting_respawn) + effective_steps = jnp.where(in_reset_phase, jnp.int32(0), state.movement_steps) + elapsed_decay = jnp.floor(effective_steps.astype(jnp.float32) / 200.0).astype(jnp.int32) min_scores = jnp.maximum(jnp.int32(10), base_scores // 3) return jnp.maximum(base_scores - elapsed_decay, min_scores)