diff --git a/.vscode/settings.json b/.vscode/settings.json index 9bd6a5a..a318490 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,10 +1,11 @@ { - "editor.formatOnSave": true, - "editor.defaultFormatter": "charliermarsh.ruff", - "editor.codeActionsOnSave": { - "source.fixAll.ruff": "always", - "source.organizeImports.ruff": "always" - }, - "python.analysis.typeCheckingMode": "standard", - "ruff.lineLength": 120, + "editor.formatOnSave": true, + "editor.defaultFormatter": "charliermarsh.ruff", + "editor.codeActionsOnSave": { + "source.fixAll.ruff": "always", + "source.organizeImports.ruff": "always" + }, + "python.analysis.typeCheckingMode": "standard", + "ruff.lineLength": 120, + "black-formatter.args": ["--line-length", "120"] } diff --git a/app.py b/app.py index 4ae0aca..7a049c5 100644 --- a/app.py +++ b/app.py @@ -2,25 +2,16 @@ from collections.abc import Coroutine from contextlib import asynccontextmanager from http import HTTPStatus -from importlib import import_module -from pathlib import Path from typing import Any -from fastapi import FastAPI, HTTPException +from fastapi import FastAPI, HTTPException, WebSocket from fastapi.middleware.cors import CORSMiddleware from pydantic import BaseModel -from src.napta_matrix import MATRIX_SCRIPTS +from src.helpers.control import RDY, SERVER_PORT, WAIT +from src.napta_matrix import MATRIX_SCRIPTS, PlayableMatrixScript -# Import scripts -THIS_DIR = Path(__file__).resolve().parent -for file in sorted(THIS_DIR.glob("src/**/*.py")): - if file.stem != "__init__": - module_path = ".".join(file.relative_to(THIS_DIR).parts).removesuffix(".py") - import_module(module_path) - - -DEFAULT_PROGRAM = MATRIX_SCRIPTS["display_screensaver"]() +DEFAULT_PROGRAM = MATRIX_SCRIPTS["display_screensaver"].function() _main_program_task: asyncio.Task[None] @@ -29,9 +20,7 @@ @asynccontextmanager async def lifespan(app: FastAPI): global _main_program_task - _main_program_task = asyncio.create_task( - DEFAULT_PROGRAM, name=DEFAULT_PROGRAM.__name__ - ) + _main_program_task = asyncio.create_task(DEFAULT_PROGRAM, name=DEFAULT_PROGRAM.__name__) try: yield finally: @@ -55,31 +44,137 @@ def switch_program(program: Coroutine[Any, Any, None]) -> None: class ScriptResponse(BaseModel): + script_id: str script_name: str + is_playable: bool class GetScriptsResponse(BaseModel): - scripts: list[str] - current_script: str + scripts: list[ScriptResponse] + current_script: ScriptResponse @app.get("/scripts", operation_id="get_scripts") async def scripts() -> GetScriptsResponse: global _main_program_task - scripts = list(MATRIX_SCRIPTS.keys()) - current_script = _main_program_task.get_name() - return GetScriptsResponse(scripts=scripts, current_script=current_script) + scripts = [ + ScriptResponse( + script_id=script_name, + script_name=script.script_name, + is_playable=isinstance(script, PlayableMatrixScript), + ) + for script_name, script in MATRIX_SCRIPTS.items() + ] + current_script = MATRIX_SCRIPTS[_main_program_task.get_name()] + return GetScriptsResponse( + scripts=scripts, + current_script=ScriptResponse( + script_id=_main_program_task.get_name(), + script_name=current_script.script_name, + is_playable=isinstance(current_script, PlayableMatrixScript), + ), + ) class ChangeScriptRequest(BaseModel): - script: str + script_id: str @app.post("/scripts/change", operation_id="post_change_script") -async def change_script(change_script_request: ChangeScriptRequest): +async def change_script(change_script_request: ChangeScriptRequest) -> None: try: - script = MATRIX_SCRIPTS[change_script_request.script] + script = MATRIX_SCRIPTS[change_script_request.script_id].function except KeyError: - raise HTTPException(HTTPStatus.UNPROCESSABLE_ENTITY, f"Unknown program: {change_script_request.script}") + raise HTTPException( + HTTPStatus.UNPROCESSABLE_ENTITY, + f"Unknown program: {change_script_request.script_id}", + ) switch_program(script()) - return "OK" + + +class PlayableScriptResponse(BaseModel): + script_id: str + script_name: str + min_player_number: int + max_player_number: int + keys: list[str] + + +@app.get("/scripts/playable/{script_id}", operation_id="get_playable_script") +async def playable_script(script_id: str) -> PlayableScriptResponse: + try: + script = MATRIX_SCRIPTS[script_id] + except KeyError: + raise HTTPException( + HTTPStatus.UNPROCESSABLE_ENTITY, + f"Unknown program: {script_id}", + ) + + if not isinstance(script, PlayableMatrixScript): + raise HTTPException( + HTTPStatus.UNPROCESSABLE_ENTITY, + f"Program {script_id} is not a playable program", + ) + + return PlayableScriptResponse( + script_id=script_id, + script_name=script.script_name, + min_player_number=script.min_player_number, + max_player_number=script.max_player_number, + keys=[key.value for key in script.keys], + ) + + +async def _handle_choose_player( + data: dict, websocket: WebSocket, reader: asyncio.StreamReader, writer: asyncio.StreamWriter +) -> None: + player_number = data.get("data", {}).get("player_number", None) + if player_number is None: + await websocket.send_json({"error": "Missing player number"}) + return + + writer.write(f"P{player_number}".encode() + b"\n") + await writer.drain() + message = await reader.readuntil(b"\n") + if message == RDY: + await websocket.send_json({"type": "status", "data": "READY"}) + elif message == WAIT: + await websocket.send_json({"type": "status", "data": "WAITING"}) + else: + await websocket.send_json({"type": "error", "data": "Unknown message"}) + + +async def _handle_key(data: dict, websocket: WebSocket, writer: asyncio.StreamWriter) -> None: + key = data.get("data", {}).get("key", None) + if key is None: + await websocket.send_json({"type": "error", "data": "Missing key"}) + return + + if key == "UP": + key = "\x1b[A" + elif key == "DOWN": + key = "\x1b[B" + elif key == "LEFT": + key = "\x1b[D" + elif key == "RIGHT": + key = "\x1b[C" + else: + await websocket.send_json({"type": "error", "data": "Unknown key"}) + return + + writer.write(key.encode()) + + +@app.websocket("/ws") +async def websocket_endpoint(websocket: WebSocket): + reader, writer = await asyncio.open_connection(host="localhost", port=SERVER_PORT) + + await websocket.accept() + while True: + data = await websocket.receive_json() + if data.get("type") == "choose_player": + await _handle_choose_player(data, websocket, reader, writer) + elif data.get("type") == "key": + await _handle_key(data, websocket, writer) + else: + await websocket.send_json({"type": "error", "data": "Unknown message"}) diff --git a/client/package.json b/client/package.json index 2fc45b1..73ab38c 100644 --- a/client/package.json +++ b/client/package.json @@ -17,7 +17,8 @@ "@tabler/icons-react": "^3.14.0", "react": "^18.3.1", "react-dom": "^18.3.1", - "react-query": "^3.39.3" + "react-query": "^3.39.3", + "react-use-websocket": "^4.8.1" }, "devDependencies": { "@eslint/js": "^9.9.0", diff --git a/client/pnpm-lock.yaml b/client/pnpm-lock.yaml index e40451d..e1dd15b 100644 --- a/client/pnpm-lock.yaml +++ b/client/pnpm-lock.yaml @@ -29,6 +29,9 @@ importers: react-query: specifier: ^3.39.3 version: 3.39.3(react-dom@18.3.1(react@18.3.1))(react@18.3.1) + react-use-websocket: + specifier: ^4.8.1 + version: 4.8.1(react-dom@18.3.1(react@18.3.1))(react@18.3.1) devDependencies: '@eslint/js': specifier: ^9.9.0 @@ -1325,6 +1328,12 @@ packages: peerDependencies: react: ^16.8.0 || ^17.0.0 || ^18.0.0 + react-use-websocket@4.8.1: + resolution: {integrity: sha512-FTXuG5O+LFozmu1BRfrzl7UIQngECvGJmL7BHsK4TYXuVt+mCizVA8lT0hGSIF0Z0TedF7bOo1nRzOUdginhDw==} + peerDependencies: + react: '>= 18.0.0' + react-dom: '>= 18.0.0' + react@18.3.1: resolution: {integrity: sha512-wS+hAgJShR0KhEvPJArfuPVN1+Hz1t0Y6n5jLrGQbkb4urgPE/0Rve+1kMB1v/oWgHgm4WIcV+i7F2pTVj+2iQ==} engines: {node: '>=0.10.0'} @@ -2822,6 +2831,11 @@ snapshots: transitivePeerDependencies: - '@types/react' + react-use-websocket@4.8.1(react-dom@18.3.1(react@18.3.1))(react@18.3.1): + dependencies: + react: 18.3.1 + react-dom: 18.3.1(react@18.3.1) + react@18.3.1: dependencies: loose-envify: 1.4.0 diff --git a/client/src/App.tsx b/client/src/App.tsx index 3449d30..c0d1ca8 100644 --- a/client/src/App.tsx +++ b/client/src/App.tsx @@ -6,9 +6,10 @@ import { MantineProvider } from "@mantine/core"; import { client } from "./api"; import { Layout } from "./components/layout"; +import { getBaseURL } from "./api/get-base-url"; client.setConfig({ - baseUrl: `http://${window.location.hostname}:8042`, + baseUrl: getBaseURL(), }); const queryClient = new QueryClient(); diff --git a/client/src/api/get-base-url.ts b/client/src/api/get-base-url.ts new file mode 100644 index 0000000..377ee57 --- /dev/null +++ b/client/src/api/get-base-url.ts @@ -0,0 +1,2 @@ +export const getBaseURL = () => `http://${window.location.hostname}:8042`; +export const getWSBaseURL = () => `ws://${window.location.hostname}:8042`; diff --git a/client/src/api/schemas.gen.ts b/client/src/api/schemas.gen.ts index 7971004..e7fd31b 100644 --- a/client/src/api/schemas.gen.ts +++ b/client/src/api/schemas.gen.ts @@ -2,13 +2,13 @@ export const $ChangeScriptRequest = { properties: { - script: { + script_id: { type: 'string', - title: 'Script' + title: 'Script Id' } }, type: 'object', - required: ['script'], + required: ['script_id'], title: 'ChangeScriptRequest' } as const; @@ -16,14 +16,13 @@ export const $GetScriptsResponse = { properties: { scripts: { items: { - type: 'string' + '$ref': '#/components/schemas/ScriptResponse' }, type: 'array', title: 'Scripts' }, current_script: { - type: 'string', - title: 'Current Script' + '$ref': '#/components/schemas/ScriptResponse' } }, type: 'object', @@ -45,6 +44,57 @@ export const $HTTPValidationError = { title: 'HTTPValidationError' } as const; +export const $PlayableScriptResponse = { + properties: { + script_id: { + type: 'string', + title: 'Script Id' + }, + script_name: { + type: 'string', + title: 'Script Name' + }, + min_player_number: { + type: 'integer', + title: 'Min Player Number' + }, + max_player_number: { + type: 'integer', + title: 'Max Player Number' + }, + keys: { + items: { + type: 'string' + }, + type: 'array', + title: 'Keys' + } + }, + type: 'object', + required: ['script_id', 'script_name', 'min_player_number', 'max_player_number', 'keys'], + title: 'PlayableScriptResponse' +} as const; + +export const $ScriptResponse = { + properties: { + script_id: { + type: 'string', + title: 'Script Id' + }, + script_name: { + type: 'string', + title: 'Script Name' + }, + is_playable: { + type: 'boolean', + title: 'Is Playable' + } + }, + type: 'object', + required: ['script_id', 'script_name', 'is_playable'], + title: 'ScriptResponse' +} as const; + export const $ValidationError = { properties: { loc: { diff --git a/client/src/api/services.gen.ts b/client/src/api/services.gen.ts index 605f6c7..e479623 100644 --- a/client/src/api/services.gen.ts +++ b/client/src/api/services.gen.ts @@ -1,7 +1,7 @@ // This file is auto-generated by @hey-api/openapi-ts import { createClient, createConfig, type Options } from '@hey-api/client-fetch'; -import type { GetScriptsError, GetScriptsResponse2, PostChangeScriptData, PostChangeScriptError, PostChangeScriptResponse } from './types.gen'; +import type { GetScriptsError, GetScriptsResponse2, PostChangeScriptData, PostChangeScriptError, PostChangeScriptResponse, GetPlayableScriptData, GetPlayableScriptError, GetPlayableScriptResponse } from './types.gen'; export const client = createClient(createConfig()); @@ -19,4 +19,12 @@ export const getScripts = (options?: Optio export const postChangeScript = (options: Options) => { return (options?.client ?? client).post({ ...options, url: '/scripts/change' +}); }; + +/** + * Playable Script + */ +export const getPlayableScript = (options: Options) => { return (options?.client ?? client).get({ + ...options, + url: '/scripts/playable/{script_id}' }); }; \ No newline at end of file diff --git a/client/src/api/types.gen.ts b/client/src/api/types.gen.ts index 48de6d9..64748da 100644 --- a/client/src/api/types.gen.ts +++ b/client/src/api/types.gen.ts @@ -1,18 +1,32 @@ // This file is auto-generated by @hey-api/openapi-ts export type ChangeScriptRequest = { - script: string; + script_id: string; }; export type GetScriptsResponse = { - scripts: Array<(string)>; - current_script: string; + scripts: Array; + current_script: ScriptResponse; }; export type HTTPValidationError = { detail?: Array; }; +export type PlayableScriptResponse = { + script_id: string; + script_name: string; + min_player_number: number; + max_player_number: number; + keys: Array<(string)>; +}; + +export type ScriptResponse = { + script_id: string; + script_name: string; + is_playable: boolean; +}; + export type ValidationError = { loc: Array<(string | number)>; msg: string; @@ -29,4 +43,14 @@ export type PostChangeScriptData = { export type PostChangeScriptResponse = (unknown); -export type PostChangeScriptError = (HTTPValidationError); \ No newline at end of file +export type PostChangeScriptError = (HTTPValidationError); + +export type GetPlayableScriptData = { + path: { + script_id: string; + }; +}; + +export type GetPlayableScriptResponse = (PlayableScriptResponse); + +export type GetPlayableScriptError = (HTTPValidationError); \ No newline at end of file diff --git a/client/src/pages/scripts.tsx b/client/src/pages/scripts.tsx deleted file mode 100644 index 74044f1..0000000 --- a/client/src/pages/scripts.tsx +++ /dev/null @@ -1,58 +0,0 @@ -import { useMutation, useQuery, useQueryClient } from "react-query"; -import { useCallback } from "react"; -import { getScripts, postChangeScript } from "../api"; -import { Button, Flex, Loader, Title } from "@mantine/core"; - -const useScripts = () => { - const queryClient = useQueryClient(); - const { data: scripts, isLoading } = useQuery({ - queryKey: ["scripts"], - queryFn: getScripts, - }); - - const { mutate } = useMutation({ - mutationFn: (script: string) => postChangeScript({ body: { script } }), - onSuccess: () => { - queryClient.invalidateQueries("scripts"); - }, - }); - - const changeScript = useCallback( - (script: string) => { - mutate(script); - }, - [mutate] - ); - - return { scripts, isLoading, changeScript }; -}; - -export const Scripts = () => { - const { scripts, isLoading, changeScript } = useScripts(); - if (isLoading || !scripts) return ; - - return ( - <> - Scripts - - {scripts.data?.scripts.map((script) => ( - - ))} - - - ); -}; diff --git a/client/src/pages/scripts/components/playable-script.tsx b/client/src/pages/scripts/components/playable-script.tsx new file mode 100644 index 0000000..b1c346f --- /dev/null +++ b/client/src/pages/scripts/components/playable-script.tsx @@ -0,0 +1,170 @@ +import { Button, Flex, Input, Kbd, Loader, Title } from "@mantine/core"; +import { + PlayableScriptKey, + PlayableScriptStatus, + usePlayableScript, +} from "../hooks/use-playable-script"; +import { useState } from "react"; + +export const PlayableScript = ({ + scriptId, + scriptName, +}: { + scriptId: string; + scriptName: string; +}) => { + const { + isLoading, + minPlayers, + choosePlayer, + maxPlayers, + status, + sendKey, + keys, + } = usePlayableScript({ + scriptId, + }); + + return ( + <> + + {scriptName.toUpperCase()} + + {isLoading ? ( + + ) : ( + <> + + + + )} + + ); +}; + +const PlayableScriptsDetails = ({ + minPlayers, + maxPlayers, + status, +}: { + minPlayers: number; + maxPlayers: number; + status: string; +}) => { + return ( +
+
Min players: {minPlayers}
+
Max players: {maxPlayers}
+
Status: {status}
+
+ ); +}; + +const PlayableScriptControl = ({ + maxPlayers, + sendKey, + status, + keys, + choosePlayer, +}: { + maxPlayers: number; + status: PlayableScriptStatus; + keys: PlayableScriptKey[]; + sendKey: (key: PlayableScriptKey) => void; + choosePlayer: (playerNumber: number) => void; +}) => { + return ( +
+ + {status === "READY" && } +
+ ); +}; + +const PlayerSelection = ({ + maxPlayers, + choosePlayer, +}: { + maxPlayers: number; + choosePlayer: (playerNumber: number) => void; +}) => { + return ( + <> + + Select a player + + + {Array.from({ length: maxPlayers }).map((_, index) => ( + + ))} + + + ); +}; + +const PlayerControl = ({ + sendKey, + keys, +}: { + sendKey: (key: PlayableScriptKey) => void; + keys: PlayableScriptKey[]; +}) => { + return ( + <> + + Control the player + + + + {keys.map((key) => ( + {key} + ))} + + { + event.preventDefault(); + event.stopPropagation(); + if (event.key === "ArrowUp") { + sendKey("UP"); + } + if (event.key === "ArrowDown") { + sendKey("DOWN"); + } + if (event.key === "ArrowLeft") { + sendKey("LEFT"); + } + if (event.key === "ArrowRight") { + sendKey("RIGHT"); + } + }} + placeholder="Press keys here to control" + /> + + + ); +}; diff --git a/client/src/pages/scripts/hooks/use-playable-script.ts b/client/src/pages/scripts/hooks/use-playable-script.ts new file mode 100644 index 0000000..4ec2521 --- /dev/null +++ b/client/src/pages/scripts/hooks/use-playable-script.ts @@ -0,0 +1,63 @@ +import { useQuery } from "react-query"; +import { getPlayableScript } from "../../../api"; +import { useCallback, useMemo } from "react"; +import { getWSBaseURL } from "../../../api/get-base-url"; +import useWebSocket, { ReadyState } from "react-use-websocket"; + +export type PlayableScriptStatus = "READY" | "WAITING"; + +export type PlayableScriptKey = "UP" | "DOWN" | "LEFT" | "RIGHT"; +type WSResponse = + | { + type: "error"; + data: string; + } + | { + type: "status"; + data: PlayableScriptStatus; + }; + +export const usePlayableScript = ({ scriptId }: { scriptId: string }) => { + const socketURL = useMemo(() => { + return `${getWSBaseURL()}/ws`; + }, []); + + const { readyState, sendJsonMessage, lastJsonMessage } = + useWebSocket(socketURL); + + const { data, isLoading } = useQuery({ + queryKey: ["playable-script", { scriptId }], + queryFn: () => getPlayableScript({ path: { script_id: scriptId } }), + }); + + const choosePlayer = useCallback( + (playerNumber: number) => { + sendJsonMessage({ + type: "choose_player", + data: { player_number: playerNumber }, + }); + }, + [sendJsonMessage] + ); + + const sendKey = useCallback( + (key: PlayableScriptKey) => { + sendJsonMessage({ + type: "key", + data: { key }, + }); + }, + [sendJsonMessage] + ); + + return { + isLoading: isLoading || readyState !== ReadyState.OPEN, + keys: data?.data?.keys, + minPlayers: data?.data?.min_player_number, + maxPlayers: data?.data?.max_player_number, + status: + lastJsonMessage?.type === "status" ? lastJsonMessage.data : "WAITING", + choosePlayer, + sendKey, + }; +}; diff --git a/client/src/pages/scripts/hooks/use-scripts.ts b/client/src/pages/scripts/hooks/use-scripts.ts new file mode 100644 index 0000000..44a7f47 --- /dev/null +++ b/client/src/pages/scripts/hooks/use-scripts.ts @@ -0,0 +1,28 @@ +import { useCallback } from "react"; +import { useQueryClient, useQuery, useMutation } from "react-query"; +import { getScripts, postChangeScript } from "../../../api"; + +export const useScripts = () => { + const queryClient = useQueryClient(); + const { data: scripts, isLoading } = useQuery({ + queryKey: ["scripts"], + queryFn: getScripts, + }); + + const { mutate } = useMutation({ + mutationFn: (scriptId: string) => + postChangeScript({ body: { script_id: scriptId } }), + onSuccess: () => { + queryClient.invalidateQueries("scripts"); + }, + }); + + const changeScript = useCallback( + (script: string) => { + mutate(script); + }, + [mutate] + ); + + return { scripts, isLoading, changeScript }; +}; diff --git a/client/src/pages/scripts/index.tsx b/client/src/pages/scripts/index.tsx new file mode 100644 index 0000000..00fd038 --- /dev/null +++ b/client/src/pages/scripts/index.tsx @@ -0,0 +1,41 @@ +import { Loader, Title, Flex, Button } from "@mantine/core"; +import { PlayableScript } from "./components/playable-script"; +import { useScripts } from "./hooks/use-scripts"; + +export const Scripts = () => { + const { scripts, isLoading, changeScript } = useScripts(); + if (isLoading || !scripts || !scripts.data) return ; + + return ( + <> + Scripts + + {scripts.data.scripts.map((script) => ( + + ))} + + {scripts.data.current_script.is_playable && ( + + )} + + ); +}; diff --git a/co b/co deleted file mode 100755 index 9b78e46..0000000 --- a/co +++ /dev/null @@ -1,5 +0,0 @@ -#!/bin/bash -# Entry point for SSH connections to connect to a control server. - -export PYTHONPATH=".":$PYTHONPATH -python3 src/client_connect.py diff --git a/local-play.sh b/local-play.sh deleted file mode 100755 index ea31e0c..0000000 --- a/local-play.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/bin/bash -# Connect to a game control server running locally. - -export PYTHONPATH=".":$PYTHONPATH -python src/client_connect.py --host=127.0.0.1 diff --git a/play.sh b/play.sh deleted file mode 100755 index c8bbe0e..0000000 --- a/play.sh +++ /dev/null @@ -1,5 +0,0 @@ -#!/bin/bash -# Connect to a game control server running on the Raspberry Pi. - -export PYTHONPATH=".":$PYTHONPATH -python src/client_connect.py diff --git a/requirements.txt b/requirements.txt index 307a0c3..325724b 100644 --- a/requirements.txt +++ b/requirements.txt @@ -3,3 +3,5 @@ requests==2.32.3 RGBMatrixEmulator==0.11.6 typing_extensions==4.12.2 uvicorn==0.30.6 +websockets==13.0.1 + diff --git a/src/client_connect.py b/src/client_connect.py deleted file mode 100644 index 6a04988..0000000 --- a/src/client_connect.py +++ /dev/null @@ -1,11 +0,0 @@ -import asyncio -from argparse import ArgumentParser - -from src.helpers.control import connect_to_server - -if __name__ == "__main__": - parser = ArgumentParser() - parser.add_argument("--host", default="192.168.128.175") - args = parser.parse_args() - - asyncio.run(connect_to_server(args.host)) diff --git a/src/display_pong.py b/src/display_pong.py index 21f5838..8c638a8 100644 --- a/src/display_pong.py +++ b/src/display_pong.py @@ -9,7 +9,7 @@ from src.helpers.draw import pattern_to_points from src.helpers.fullscreen_message import fullscreen_message from src.helpers.napta_colors import NaptaColor -from src.napta_matrix import RGBMatrix, matrix_script +from src.napta_matrix import KeyboardKeys, RGBMatrix, playable_matrix_script BOARD_SIZE = 64 PADDLE_SIZE = 10 @@ -106,7 +106,11 @@ def _score_points(score: int, player: Literal[1, 2, 3, 4]) -> set[tuple[int, int } -@matrix_script +@playable_matrix_script( + min_player_number=2, + max_player_number=4, + keys=[KeyboardKeys.UP, KeyboardKeys.DOWN, KeyboardKeys.LEFT, KeyboardKeys.RIGHT], +) async def display_pong(matrix: RGBMatrix) -> None: def draw_point(pix: tuple[int, int], color: tuple[int, int, int]) -> None: matrix.SetPixel(*pix, *color) @@ -321,14 +325,7 @@ def goal(player: Literal[0, 1, 2, 3, 4], player_looser: Literal[0, 1, 2, 3, 4]) await fullscreen_message(matrix, ["Starting", "Pong game", "server..."]) on_started = fullscreen_message( matrix, - [ - "Connect to", - "play Pong:", - "./play.sh", - "in the repo", - "(Web client", - "incoming)", - ], + ["Waiting for", "players..."], ) client_names = ["P1", "P2", "P3", "P4"] diff --git a/src/display_snake.py b/src/display_snake.py index 8809496..5ddabbe 100644 --- a/src/display_snake.py +++ b/src/display_snake.py @@ -39,7 +39,9 @@ def get_dir(current_dir: Dir, input: bytes) -> Dir: @matrix_script async def display_snake(matrix: RGBMatrix) -> None: - snake = deque(((20 + i) % BOARD_SIZE, 40) for i in range(INITIAL_SNAKE_LEN, 0, -1)) # Head to queue + snake = deque( + ((20 + i) % BOARD_SIZE, 40) for i in range(INITIAL_SNAKE_LEN, 0, -1) + ) # Head to queue def get_next_apple() -> tuple[int, int]: while (maybe_apple := (randrange(BOARD_SIZE), randrange(BOARD_SIZE))) in snake: @@ -93,7 +95,15 @@ def update_game() -> None: await fullscreen_message(matrix, ["Starting", "Snake game", "server..."]) on_started = fullscreen_message( - matrix, ["Connect to", "play Snake:", "./play.sh", "in the repo", "(Web client", "incoming)"] + matrix, + [ + "Connect to", + "play Snake:", + "./play.sh", + "in the repo", + "(Web client", + "incoming)", + ], ) async with control_server(client_names=["P"], on_started=on_started) as server: @@ -103,7 +113,9 @@ def update_game() -> None: while True: t_start = time.time() try: - input = await asyncio.wait_for(server.clients["P"].read(32), timeout=timeout) + input = await asyncio.wait_for( + server.clients["P"].read(32), timeout=timeout + ) dir = get_dir(dir, input) except asyncio.TimeoutError: pass diff --git a/src/helpers/control.py b/src/helpers/control.py index ad3ff04..2315b2b 100644 --- a/src/helpers/control.py +++ b/src/helpers/control.py @@ -4,17 +4,24 @@ from contextlib import asynccontextmanager from typing import Optional -from src.helpers import ainput +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) SERVER_PORT = 4422 -INP = b"INP\n" RDY = b"RDY\n" +WAIT = b"WAIT\n" +WRONG = b"WRG\n" class ControlServer: - def __init__(self) -> None: + def __init__(self, *, min_clients: int) -> None: self.clients = dict[str, asyncio.StreamReader]() + self.min_clients = min_clients + + def can_start(self) -> bool: + return len(self.clients) >= self.min_clients @asynccontextmanager @@ -25,59 +32,34 @@ async def control_server( ) -> AsyncIterator[ControlServer]: if min_clients is None: min_clients = len(client_names) + _client_names = [name.encode() for name in client_names] - server = ControlServer() + server = ControlServer(min_clients=min_clients) - async def client_connected( - reader: asyncio.StreamReader, writer: asyncio.StreamWriter - ): + async def client_connected(reader: asyncio.StreamReader, writer: asyncio.StreamWriter): if len(client_names) == 1: [client_name] = client_names else: while True: - writer.write(b"Who are you? (choices: %s)\n" % b", ".join(_client_names)) - writer.write(INP) - await writer.drain() proposal = (await reader.readuntil(b"\n")).strip() if proposal in _client_names: client_name = proposal.decode() break - writer.write(b"Wrong answer...\n\n") + writer.write(WRONG) - writer.write(RDY) - await writer.drain() server.clients[client_name] = reader - logging.info(f"Creating TCP server on port {SERVER_PORT}...") + if server.can_start(): + writer.write(RDY) + else: + writer.write(WAIT) + await writer.drain() + + logger.info(f"Creating TCP server on port {SERVER_PORT}...") async with await asyncio.start_server(client_connected, host="0.0.0.0", port=SERVER_PORT): - logging.info("Server ready!") if on_started: await on_started - while len(server.clients) < min_clients: + while not server.can_start(): await asyncio.sleep(1) yield server - - -async def connect_to_server(host: str) -> None: - print("Connecting...") - reader, writer = await asyncio.open_connection(host=host, port=SERVER_PORT) - print("Connected!") - while True: - message = await reader.readuntil(b"\n") - if message == RDY: - break - elif message == INP: - inp = input(">>> ") - writer.write(inp.encode() + b"\n") - await writer.drain() - else: - print(message.decode(), end="") - - print("Ready!") - - with ainput.capture_terminal() as get_input: - while True: - if input_ := get_input(): - writer.write(input_) - await asyncio.sleep(0.01) diff --git a/src/napta_matrix.py b/src/napta_matrix.py index 8c5bf06..bda2569 100644 --- a/src/napta_matrix.py +++ b/src/napta_matrix.py @@ -1,9 +1,10 @@ import asyncio +import enum import logging import os from collections.abc import Callable, Coroutine from functools import lru_cache, wraps -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, NamedTuple, cast from typing_extensions import Concatenate, ParamSpec @@ -42,7 +43,50 @@ def _get_matrix() -> RGBMatrix: return RGBMatrix(options=options) -MATRIX_SCRIPTS = dict[str, Callable[..., Coroutine[Any, Any, None]]]() +class MatrixScript(NamedTuple): + function: Callable[..., Coroutine[Any, Any, None]] + script_name: str + + +class KeyboardKeys(enum.Enum): + UP = "UP" + DOWN = "DOWN" + LEFT = "LEFT" + RIGHT = "RIGHT" + + +class PlayableMatrixScript(NamedTuple): + function: Callable[..., Coroutine[Any, Any, None]] + script_name: str + min_player_number: int + max_player_number: int + keys: list[KeyboardKeys] + + +MATRIX_SCRIPTS = dict[str, MatrixScript | PlayableMatrixScript]() + + +async def _handle_matrix_script_exception(matrix: RGBMatrix, function_name: str) -> None: + from src.display_screensaver import display_screensaver + from src.helpers.fullscreen_message import fullscreen_message + + logging.exception(f"Fatal error in program {function_name!r}: ", exc_info=True) + program_name = function_name.removeprefix("display_") + await fullscreen_message( + matrix, + [ + "Fatal error", + "in program", + program_name, + "", + "Restarting", + "in a few", + "seconds...", + ], + color=cast(tuple[int, int, int], NaptaColor.BITTERSWEET), + ) + await asyncio.sleep(5) + await asyncio.create_task(display_screensaver()) def matrix_script( @@ -55,31 +99,47 @@ async def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: try: await function(matrix, *args, **kwargs) except Exception: - from src.display_screensaver import display_screensaver - from src.helpers.fullscreen_message import fullscreen_message - - logging.exception(f"Fatal error in program {function.__name__!r}: ", exc_info=True) - program_name = function.__name__.removeprefix("display_") - await fullscreen_message( - matrix, - [ - "Fatal error", - "in program", - program_name, - "", - "Restarting", - "in a few", - "seconds...", - ], - color=NaptaColor.BITTERSWEET, - ) - await asyncio.sleep(5) - await asyncio.create_task(display_screensaver()) + await _handle_matrix_script_exception(matrix, function.__name__) raise - MATRIX_SCRIPTS[function.__name__] = wrapper + MATRIX_SCRIPTS[function.__name__] = MatrixScript( + function=wrapper, script_name=function.__name__.replace("display_", "") + ) return wrapper -__all__ = ["RGBMatrix", "RGBMatrixOptions", "graphics", "matrix_script"] +def playable_matrix_script(*, min_player_number: int, max_player_number: int, keys: list[KeyboardKeys]): + def matrix_script( + function: Callable[Concatenate[RGBMatrix, _P], Coroutine[Any, Any, None]], + ) -> Callable[_P, Coroutine[Any, Any, None]]: + @wraps(function) + async def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> None: + matrix = _get_matrix() + matrix.Clear() + try: + await function(matrix, *args, **kwargs) + except Exception: + await _handle_matrix_script_exception(matrix, function.__name__) + raise + + MATRIX_SCRIPTS[function.__name__] = PlayableMatrixScript( + function=wrapper, + script_name=function.__name__.replace("display_", ""), + min_player_number=min_player_number, + max_player_number=max_player_number, + keys=keys, + ) + + return wrapper + + return matrix_script + + +__all__ = [ + "RGBMatrix", + "RGBMatrixOptions", + "graphics", + "matrix_script", + "playable_matrix_script", +]