diff --git a/client/.gitignore b/client/.gitignore new file mode 100644 index 0000000..15bdabd --- /dev/null +++ b/client/.gitignore @@ -0,0 +1,11 @@ +# Generated proto stubs (regenerate with generate_proto.sh) +nightshade_client/grpc_stubs/nightshade_pb2.py +nightshade_client/grpc_stubs/nightshade_pb2.pyi +nightshade_client/grpc_stubs/nightshade_pb2_grpc.py + +# Build artifacts +*.egg-info/ +__pycache__/ +*.pyc +dist/ +build/ diff --git a/client/CLAUDE.md b/client/CLAUDE.md new file mode 100644 index 0000000..b59ed2e --- /dev/null +++ b/client/CLAUDE.md @@ -0,0 +1,498 @@ +# Nightshade Client SDK — Part B Implementation Guide + +Auto-loaded by Claude Code. Everything you need to implement the Windows capture and +processing backends is in this file. + +--- + +## Architecture Overview + +``` +nightshade_client/ +├── connection/ # Part A - gRPC channel, auth, health check +├── session/ # Part A - session lifecycle (start/end), monitor +├── streams/ # Part A - chat, voice, evidence stream managers +├── pipeline/ +│ ├── base.py # ABCs + dataclasses + Mock implementations (interface contract) +│ ├── buffer.py # RingBuffer (dashcam circular buffer) +│ └── orchestrator.py # Wires everything together (main async loop) +├── capture/ # Part B stubs - IMPLEMENT HERE +│ ├── screen.py → MSSScreenCapture +│ ├── audio.py → SoundDeviceCapture +│ └── video.py → NVENCVideoRecorder +├── processing/ # Part B stubs - IMPLEMENT HERE +│ ├── ocr.py → PaddleOCREngine +│ ├── stt.py → WhisperSTT +│ └── diarize.py → NeMoDiarizer +├── hud/ # Part B stub - IMPLEMENT HERE +│ └── overlay.py → DearPyGuiOverlay +├── config.py # ClientSettings (pydantic-settings, NIGHTSHADE_ env prefix) +└── __main__.py # CLI entry point +``` + +### Part A (complete - do not modify) +- gRPC channel management and TLS auth (`connection/`) +- Session lifecycle: StartSession / EndSession RPCs (`session/`) +- Stream managers: ChatStreamManager, VoiceStreamManager, EvidenceStreamManager (`streams/`) +- Pipeline orchestrator with capture loop and audio loop (`pipeline/orchestrator.py`) +- RingBuffer dashcam (saves JSON marker; Part B encodes the actual video) (`pipeline/buffer.py`) +- Mock implementations for all ABCs (in `pipeline/base.py`) +- CLI (`__main__.py`) + +### Part B (to implement — this session) +Real Windows backends: screen capture, audio capture, NVENC video encoding, +PaddleOCR, faster-whisper STT, NeMo diarization, DearPyGui HUD overlay. + +Each Part B class lives in its stub file, implements the ABC from `pipeline/base.py`, +and is injected into `PipelineOrchestrator` by the caller (or CLI code). + +--- + +## Abstract Base Classes — Full Interface Contract + +These are copied verbatim from `pipeline/base.py`. Implement every `@abstractmethod`. + +### Dataclasses + +```python +@dataclass +class CapturedFrame: + """A single screen capture frame.""" + image: np.ndarray # HxWxC uint8 BGR + timestamp: float # unix seconds + width: int = 0 + height: int = 0 + + def __post_init__(self): + if self.width == 0 and self.image is not None: + self.height, self.width = self.image.shape[:2] + + +@dataclass +class CapturedAudio: + """A chunk of captured audio.""" + samples: np.ndarray # float32 mono PCM + sample_rate: int = 16000 + timestamp: float = 0.0 + duration_s: float = 0.0 + + +@dataclass +class OCRResult: + """OCR output from a single frame.""" + texts: list[str] = field(default_factory=list) + confidences: list[float] = field(default_factory=list) + boxes: list[tuple[int, int, int, int]] = field(default_factory=list) # x,y,w,h + timestamp: float = 0.0 + frame_width: int = 0 + frame_height: int = 0 + + +@dataclass +class TranscriptSegment: + """A single speech-to-text segment.""" + text: str = "" + speaker_label: str = "" + confidence: float = 0.0 + diarization_confidence: float = 0.0 + start_time: float = 0.0 + end_time: float = 0.0 + timestamp: float = 0.0 +``` + +### ScreenCapture ABC + +```python +class ScreenCapture(ABC): + """Captures screen frames.""" + + @abstractmethod + async def start(self) -> None: + """Initialize capture resources.""" + + @abstractmethod + async def capture_frame(self, region: Optional[tuple[int, int, int, int]] = None) -> CapturedFrame: + """Capture a single frame, optionally from a specific region (x,y,w,h).""" + + @abstractmethod + async def stop(self) -> None: + """Release capture resources.""" +``` + +### AudioCapture ABC + +```python +class AudioCapture(ABC): + """Captures audio from a device.""" + + @abstractmethod + async def start(self, device: Optional[str] = None, sample_rate: int = 16000) -> None: + """Start audio capture from the specified device.""" + + @abstractmethod + async def read_chunk(self, duration_s: float = 5.0) -> CapturedAudio: + """Read a chunk of audio of the given duration.""" + + @abstractmethod + async def stop(self) -> None: + """Stop audio capture and release resources.""" +``` + +### VideoRecorder ABC + +```python +class VideoRecorder(ABC): + """Records video (e.g., NVENC-accelerated).""" + + @abstractmethod + async def start(self, output_path: Path, fps: int = 30) -> None: + """Start recording to the given path.""" + + @abstractmethod + async def write_frame(self, frame: CapturedFrame) -> None: + """Write a frame to the recording.""" + + @abstractmethod + async def stop(self) -> Path: + """Stop recording, finalize file, return path.""" +``` + +### OCREngine ABC + +```python +class OCREngine(ABC): + """Runs OCR on captured frames.""" + + @abstractmethod + async def initialize(self, model: str = "default") -> None: + """Load OCR model.""" + + @abstractmethod + async def extract_text(self, frame: CapturedFrame) -> OCRResult: + """Extract text from a frame.""" + + @abstractmethod + async def shutdown(self) -> None: + """Unload model and free resources.""" +``` + +### SpeechToText ABC + +```python +class SpeechToText(ABC): + """Transcribes audio to text.""" + + @abstractmethod + async def initialize(self, model: str = "base") -> None: + """Load STT model.""" + + @abstractmethod + async def transcribe(self, audio: CapturedAudio) -> list[TranscriptSegment]: + """Transcribe an audio chunk into segments.""" + + @abstractmethod + async def shutdown(self) -> None: + """Unload model and free resources.""" +``` + +### SpeakerDiarizer ABC + +```python +class SpeakerDiarizer(ABC): + """Identifies speaker segments in audio.""" + + @abstractmethod + async def initialize(self) -> None: + """Load diarization model.""" + + @abstractmethod + async def diarize(self, audio: CapturedAudio) -> list[TranscriptSegment]: + """Assign speaker labels to audio segments.""" + + @abstractmethod + async def shutdown(self) -> None: + """Unload model and free resources.""" +``` + +### HUDOverlay ABC + +```python +class HUDOverlay(ABC): + """Heads-up display overlay for operator feedback.""" + + @abstractmethod + async def start(self) -> None: + """Initialize and show the HUD window.""" + + @abstractmethod + async def update_status(self, status: str) -> None: + """Update the main status text.""" + + @abstractmethod + async def show_alert(self, alert_type: str, detail: str) -> None: + """Display an alert notification.""" + + @abstractmethod + async def update_stats( + self, + messages: int = 0, + alerts: int = 0, + subjects: int = 0, + ) -> None: + """Update session statistics display.""" + + @abstractmethod + async def stop(self) -> None: + """Close the HUD window.""" +``` + +--- + +## Proto Contract — Messages the Client Sends + +### ChatMessageUpload +Sent by `ChatStreamManager` for each OCR-detected chat line. + +| Field | Type | Notes | +|-------|------|-------| +| `session_id` | str | Active session UUID | +| `username` | str | Parsed from OCR — "[username]: msg" format | +| `user_id` | str | Roblox user ID (empty string if unknown) | +| `message` | str | Chat message text | +| `source` | str | `"game_chat"` (from OCR) or `"voice"` | +| `confidence` | float | OCR confidence 0.0-1.0 | +| `risk_flags` | repeated str | Leave empty; server classifies | +| `timestamp` | int | Unix milliseconds | + +### VoiceTranscriptUpload +Sent by `VoiceStreamManager` for each STT segment. + +| Field | Type | Notes | +|-------|------|-------| +| `session_id` | str | Active session UUID | +| `speaker_label` | str | e.g. `"speaker_0"` from NeMo diarization | +| `user_id_guess` | str | Leave empty; server correlates | +| `text` | str | Transcribed speech text | +| `confidence` | float | Whisper confidence 0.0-1.0 | +| `diarization_confidence` | float | NeMo diarization confidence | +| `risk_flags` | repeated str | Leave empty | +| `timestamp` | int | Unix milliseconds | + +### EvidenceChunk +Sent by `EvidenceStreamManager` when uploading a dashcam recording. + +| Field | Type | Notes | +|-------|------|-------| +| `session_id` | str | Active session UUID | +| `filename` | str | Basename of the file, e.g. `"dashcam_alert_xyz.mp4"` | +| `chunk_data` | bytes | Raw file bytes (chunked to `evidence_chunk_size`, default 1 MB) | +| `chunk_index` | int | 0-based chunk index | +| `total_chunks` | int | Total number of chunks for this file | +| `sha256_hash` | str | Full-file SHA-256 hex (all chunks carry same hash) | +| `file_type` | str | `"recording"` for dashcam video | + +### ClientEvent (Heartbeat) +Sent periodically by `SessionMonitor` to keep the session alive. + +| Field | Type | Notes | +|-------|------|-------| +| `session_id` | str | Active session UUID | +| `heartbeat.timestamp` | int | Unix milliseconds | +| `heartbeat.messages_buffered` | int | Current chat queue depth | + +--- + +## Part B Implementation Checklist + +Implement each class in the listed file. Each file already has the stub imports. +Add your class below the existing imports. Do not break the `__all__` exports. + +| File | Class to implement | ABC | Library | +|------|-------------------|-----|---------| +| `capture/screen.py` | `MSSScreenCapture` | `ScreenCapture` | `mss` | +| `capture/audio.py` | `SoundDeviceCapture` | `AudioCapture` | `sounddevice` | +| `capture/video.py` | `NVENCVideoRecorder` | `VideoRecorder` | `ffmpeg` subprocess (NVENC encoder) | +| `processing/ocr.py` | `PaddleOCREngine` | `OCREngine` | `paddleocr` | +| `processing/stt.py` | `WhisperSTT` | `SpeechToText` | `faster_whisper` | +| `processing/diarize.py` | `NeMoDiarizer` | `SpeakerDiarizer` | `nemo_toolkit` | +| `hud/overlay.py` | `DearPyGuiOverlay` | `HUDOverlay` | `dearpygui` | + +### Implementation Notes Per Class + +**MSSScreenCapture** (`capture/screen.py`) +- `mss.mss()` context manager — keep it open in `start()`, close in `stop()` +- `capture_frame(region)`: region is `(x, y, w, h)`. Convert mss bbox `{"top": y, "left": x, "width": w, "height": h}` +- mss returns BGRA; strip alpha channel → BGR uint8 `np.ndarray` for `CapturedFrame.image` +- If region is None, capture the primary monitor (`sct.monitors[1]`) +- `timestamp = time.time()` immediately after the grab + +**SoundDeviceCapture** (`capture/audio.py`) +- Use `sounddevice.InputStream` in `start()`, keep open, close in `stop()` +- `read_chunk(duration_s)`: read `int(sample_rate * duration_s)` frames, blocking +- Output must be float32 mono; if stereo input, average channels +- `device` param matches `sounddevice` device name or index (from `NIGHTSHADE_AUDIO_DEVICE`) + +**NVENCVideoRecorder** (`capture/video.py`) +- Use `ffmpeg` subprocess with NVENC: `-c:v h264_nvenc` +- Input: raw BGR frames via stdin pipe (`-f rawvideo -pix_fmt bgr24 -s WxH -r FPS -i pipe:0`) +- Output: H.264 MP4 to `output_path` +- `write_frame(frame)`: write `frame.image.tobytes()` to the ffmpeg stdin pipe +- `stop()`: close stdin pipe, wait for subprocess, return output_path +- Detect frame dimensions from the first `write_frame()` call; start ffmpeg lazily on first frame + +**PaddleOCREngine** (`processing/ocr.py`) +- `initialize()`: `PaddleOCR(use_angle_cls=True, lang='en', use_gpu=True)`; `model` param is unused for now +- `extract_text(frame)`: call `ocr.ocr(frame.image, cls=True)` +- PaddleOCR returns `[[[box_coords, (text, confidence)], ...]]` — flatten the outer list +- Each box_coords is 4 corner points; compute bounding rect `(x, y, w, h)` from min/max +- Filter out results with confidence below 0.5 +- Set `OCRResult.timestamp = frame.timestamp`, `frame_width = frame.width`, `frame_height = frame.height` + +**WhisperSTT** (`processing/stt.py`) +- `initialize(model)`: `WhisperModel(model, device="cuda", compute_type="float16")` +- `transcribe(audio)`: `model.transcribe(audio.samples, beam_size=5)` +- faster-whisper returns `(segments_generator, info)` — iterate segments +- Each segment has `.text`, `.start`, `.end`, `.avg_logprob` (confidence proxy) +- Map to `TranscriptSegment(text, start_time=seg.start, end_time=seg.end, confidence=logprob_to_conf(seg.avg_logprob), timestamp=audio.timestamp)` +- logprob → confidence: `max(0.0, min(1.0, math.exp(avg_logprob)))` + +**NeMoDiarizer** (`processing/diarize.py`) +- `initialize()`: load `nemo.collections.asr.models.ClusteringDiarizer` config +- Diarization requires writing audio to a temp WAV file (NeMo does not accept numpy directly) +- `diarize(audio)`: write `audio.samples` to temp WAV with `soundfile.write()`, run diarizer, parse RTTM output +- RTTM line: `SPEAKER 1 ` +- Return list of `TranscriptSegment(speaker_label=label, start_time=start, end_time=start+dur, diarization_confidence=conf, timestamp=audio.timestamp)` +- Clean up temp files in `shutdown()` + +**DearPyGuiOverlay** (`hud/overlay.py`) +- Run DearPyGui in a separate thread (it requires its own thread with an event loop) +- `start()`: launch the thread, create a frameless transparent window pinned to a corner +- `update_status()`, `show_alert()`, `update_stats()`: use `dearpygui.set_value()` on named items; guard with thread-safe queue if needed +- `stop()`: call `dearpygui.stop_dearpygui()`, join the thread +- Keep the overlay non-blocking to the async pipeline: all DPG calls from the DPG thread, async methods just enqueue updates + +--- + +## Config Keys Relevant to Part B + +All env vars use the `NIGHTSHADE_` prefix. Can also be set in a `.env` file. + +| Setting | Env Var | Default | Used By | +|---------|---------|---------|---------| +| `capture_interval_ms` | `NIGHTSHADE_CAPTURE_INTERVAL_MS` | `500` | orchestrator `_capture_loop` — sleep between frames | +| `ocr_region` | `NIGHTSHADE_OCR_REGION` | `None` | Passed as `region` to `ScreenCapture.capture_frame()` — format `"x,y,w,h"` | +| `audio_device` | `NIGHTSHADE_AUDIO_DEVICE` | `None` | Passed to `AudioCapture.start(device=...)` | +| `whisper_model` | `NIGHTSHADE_WHISPER_MODEL` | `"base"` | Passed to `SpeechToText.initialize(model=...)` — `"tiny"`, `"base"`, `"small"`, `"medium"`, `"large-v3"` | +| `dashcam_buffer_seconds` | `NIGHTSHADE_DASHCAM_BUFFER_SECONDS` | `300` | `RingBuffer(max_seconds=...)` — 5 min rolling window | +| `evidence_chunk_size` | `NIGHTSHADE_EVIDENCE_CHUNK_SIZE` | `1048576` | `EvidenceStreamManager` — 1 MB chunks for file upload | +| `buffer_max_messages` | `NIGHTSHADE_BUFFER_MAX_MESSAGES` | `1000` | Chat/voice queue depth cap | +| `log_level` | `NIGHTSHADE_LOG_LEVEL` | `"INFO"` | `logging.basicConfig` | + +--- + +## Data Flow + +``` +ScreenCapture.capture_frame() + → RingBuffer.append(frame) # dashcam rolling window + → OCREngine.extract_text(frame) + → parse "[username]: message" + → ChatStreamManager.enqueue() + → gRPC StreamChatMessages + +AudioCapture.read_chunk() + → SpeechToText.transcribe() # faster-whisper + → SpeakerDiarizer.diarize() # NeMo (parallel) + → merge diarization labels into segments + → VoiceStreamManager.enqueue() + → gRPC StreamVoiceTranscripts + +Server AlertNotification received + → HUDOverlay.show_alert() + → RingBuffer.lock_and_save() # write JSON marker (Part A stub) + → NVENCVideoRecorder encodes frames # Part B upgrades this + → EvidenceStreamManager.upload_file() + → gRPC UploadEvidence (chunked) +``` + +--- + +## Testing Instructions + +### Run existing tests (must still pass after your changes) +``` +cd C:\path\to\client +python -m pytest tests/ -v +``` + +### Mock mode (no GPU, no capture hardware needed — tests the pipeline wiring) +``` +python -m nightshade_client start --mock --server localhost:50051 --api-key test --operator-id 00000000-0000-0000-0000-000000000000 --game "Test" +``` + +### Real mode (requires running Nightshade server) +``` +python -m nightshade_client start \ + --server nightshade-server:50051 \ + --api-key YOUR_KEY \ + --operator-id YOUR_UUID \ + --game "Roblox" +``` + +### Verify individual backends without the full pipeline +```python +import asyncio, numpy as np +from nightshade_client.capture.screen import MSSScreenCapture + +async def test(): + cap = MSSScreenCapture() + await cap.start() + frame = await cap.capture_frame() + print(f"Frame: {frame.width}x{frame.height}, dtype={frame.image.dtype}") + await cap.stop() + +asyncio.run(test()) +``` + +### Injecting real backends into the orchestrator +```python +from nightshade_client.pipeline.orchestrator import PipelineOrchestrator +from nightshade_client.capture.screen import MSSScreenCapture +from nightshade_client.capture.audio import SoundDeviceCapture +from nightshade_client.processing.ocr import PaddleOCREngine +from nightshade_client.processing.stt import WhisperSTT +from nightshade_client.processing.diarize import NeMoDiarizer +from nightshade_client.hud.overlay import DearPyGuiOverlay + +orchestrator = PipelineOrchestrator( + settings=settings, + screen=MSSScreenCapture(), + audio=SoundDeviceCapture(), + ocr=PaddleOCREngine(), + stt=WhisperSTT(), + diarizer=NeMoDiarizer(), + hud=DearPyGuiOverlay(), + use_mocks=False, +) +``` + +Note: `use_mocks=False` with no components provided still falls back to mocks (safe). +Pass real instances explicitly to activate real backends. + +--- + +## Key File Paths + +| File | Purpose | +|------|---------| +| `nightshade_client/pipeline/base.py` | ABCs, dataclasses, Mock implementations — read-only reference | +| `nightshade_client/pipeline/orchestrator.py` | Main pipeline loop — read-only, shows how ABCs are called | +| `nightshade_client/capture/screen.py` | Add `MSSScreenCapture` here | +| `nightshade_client/capture/audio.py` | Add `SoundDeviceCapture` here | +| `nightshade_client/capture/video.py` | Add `NVENCVideoRecorder` here | +| `nightshade_client/processing/ocr.py` | Add `PaddleOCREngine` here | +| `nightshade_client/processing/stt.py` | Add `WhisperSTT` here | +| `nightshade_client/processing/diarize.py` | Add `NeMoDiarizer` here | +| `nightshade_client/hud/overlay.py` | Add `DearPyGuiOverlay` here | +| `requirements-windows.txt` | Windows-only deps — all libraries already listed | +| `pyproject.toml` | Base deps — do not add Windows-specific libs here | diff --git a/client/WINDOWS_BUILD.md b/client/WINDOWS_BUILD.md new file mode 100644 index 0000000..9be705e --- /dev/null +++ b/client/WINDOWS_BUILD.md @@ -0,0 +1,347 @@ +# Nightshade Client — Windows Build Guide + +Step-by-step environment setup for Windows with NVIDIA GPU acceleration. + +--- + +## Prerequisites + +- Windows 10 or Windows 11 (64-bit) +- NVIDIA GPU — RTX 3060 or better recommended (8 GB VRAM minimum for all models simultaneously) +- CUDA 12.x +- cuDNN 8.x +- Python 3.11 or 3.12 (3.11 recommended for NeMo compatibility) +- Git for Windows + +--- + +## 1. CUDA Setup + +``` +1. Download CUDA 12.x installer from: + https://developer.nvidia.com/cuda-downloads + Select: Windows → x86_64 → 11 or 12 → exe (local) + +2. Run installer with default options (includes driver if needed) + +3. Verify CUDA installed correctly: + nvcc --version + Expected output: Cuda compilation tools, release 12.x + +4. Download cuDNN 8.x from: + https://developer.nvidia.com/cudnn + (Requires free NVIDIA Developer account) + Select the version matching your CUDA 12.x install + +5. Extract cuDNN zip. Copy files into CUDA install directory: + cuDNN\bin\* → C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.x\bin\ + cuDNN\include\* → C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.x\include\ + cuDNN\lib\x64\* → C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.x\lib\x64\ + +6. Verify cuDNN is on PATH: + where cudnn64_8.dll + Expected: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.x\bin\cudnn64_8.dll +``` + +--- + +## 2. Python Environment + +``` +# Clone the repo (or copy the client directory to the Windows machine) +git clone https://github.com/nullEFFORT/homelab-tools.git +cd homelab-tools\nightshade\client + +# Create virtual environment +python -m venv .venv + +# Activate +.venv\Scripts\activate + +# Install base SDK dependencies +pip install -e ".[dev]" + +# Install Windows GPU dependencies +pip install -r requirements-windows.txt +``` + +The `requirements-windows.txt` contains: +``` +paddlepaddle-gpu>=2.6 +paddleocr>=2.7 +faster-whisper>=1.0 +nemo_toolkit[asr]>=1.22 +sounddevice>=0.4 +mss>=9.0 +dearpygui>=1.10 +``` + +--- + +## 3. PaddleOCR Setup + +PaddlePaddle-GPU must be installed separately before paddleocr to get the CUDA build. + +``` +# Install GPU-enabled PaddlePaddle first (match your CUDA version) +pip install paddlepaddle-gpu==2.6.1.post120 -f https://www.paddlepaddle.org.cn/whl/windows/mkl/avx/stable.html + +# Then install paddleocr +pip install paddleocr + +# Verify GPU is detected +python -c "import paddle; print('GPU available:', paddle.device.is_compiled_with_cuda())" +Expected: GPU available: True + +# Run a full smoke test (downloads models on first run — ~500 MB) +python -c " +from paddleocr import PaddleOCR +import numpy as np +ocr = PaddleOCR(use_angle_cls=True, lang='en', use_gpu=True) +dummy = np.zeros((100, 400, 3), dtype='uint8') +result = ocr.ocr(dummy, cls=True) +print('PaddleOCR OK, result:', result) +" +``` + +--- + +## 4. Faster Whisper Setup + +``` +pip install faster-whisper + +# Verify CUDA transcription works (downloads model on first run — ~150 MB for 'base') +python -c " +from faster_whisper import WhisperModel +import numpy as np +m = WhisperModel('base', device='cuda', compute_type='float16') +audio = np.zeros(16000, dtype='float32') # 1 second of silence +segments, info = m.transcribe(audio, beam_size=5) +print('WhisperModel OK, language detected:', info.language) +" +``` + +Available model sizes (trade VRAM for accuracy): +| Model | VRAM | Speed | +|-------|------|-------| +| `tiny` | ~1 GB | fastest | +| `base` | ~1 GB | fast (default) | +| `small` | ~2 GB | balanced | +| `medium` | ~5 GB | accurate | +| `large-v3` | ~10 GB | most accurate | + +Set model via env var: `NIGHTSHADE_WHISPER_MODEL=small` + +--- + +## 5. NeMo Diarization Setup + +NeMo has a large dependency chain. Install with the `[asr]` extra. + +``` +# Install NeMo ASR (includes diarization models) +pip install nemo_toolkit[asr] + +# Some NeMo deps may need manual installation on Windows: +pip install soundfile librosa + +# Verify NeMo loads +python -c "import nemo; print('NeMo version:', nemo.__version__)" + +# Verify diarization model loads (downloads ~400 MB on first run) +python -c " +import nemo.collections.asr as nemo_asr +print('NeMo ASR collections OK') +" +``` + +Note: NeMo diarization requires writing audio to a temp WAV file. The +`NeMoDiarizer` implementation handles this automatically. + +--- + +## 6. FFmpeg with NVENC + +The `NVENCVideoRecorder` uses FFmpeg as a subprocess for hardware-accelerated encoding. + +``` +1. Download FFmpeg from: https://www.gyan.dev/ffmpeg/builds/ + Select: ffmpeg-release-essentials.zip + +2. Extract to C:\ffmpeg\ + +3. Add to system PATH: + System Properties → Environment Variables → Path → Add: C:\ffmpeg\bin + +4. Verify NVENC is available: + ffmpeg -encoders | findstr nvenc + Expected output includes: V..... h264_nvenc + +5. Test NVENC encoding: + ffmpeg -f lavfi -i color=c=black:s=1920x1080:r=30 -t 2 -c:v h264_nvenc -y test_nvenc.mp4 + Check: test_nvenc.mp4 exists and is non-zero size +``` + +--- + +## 7. GPU Verification One-Liners + +Run these before starting any real capture session to confirm all GPU stacks are working: + +``` +# PyTorch CUDA (required by NeMo and faster-whisper internally) +python -c "import torch; print(f'CUDA: {torch.cuda.is_available()}, GPU: {torch.cuda.get_device_name(0)}')" + +# PaddlePaddle GPU +python -c "import paddle; print(f'PaddlePaddle GPU: {paddle.device.is_compiled_with_cuda()}')" + +# Whisper on CUDA +python -c "from faster_whisper import WhisperModel; m = WhisperModel('tiny', device='cuda'); print('Whisper CUDA OK')" + +# Screen capture +python -c "import mss; s = mss.mss(); print('Monitors:', len(s.monitors) - 1); s.close()" + +# Audio devices +python -c "import sounddevice; print(sounddevice.query_devices())" +``` + +--- + +## 8. Running the Client + +### Mock mode (no GPU, no server needed — verify pipeline wiring) +``` +python -m nightshade_client start --mock --server localhost:50051 --api-key test --operator-id 00000000-0000-0000-0000-000000000000 --game "Roblox" +``` + +### Real mode (requires running Nightshade server) +``` +python -m nightshade_client start ^ + --server nightshade-server:50051 ^ + --api-key YOUR_API_KEY ^ + --operator-id YOUR_OPERATOR_UUID ^ + --game "Roblox" +``` + +### With explicit game session metadata +``` +python -m nightshade_client start ^ + --server nightshade-server:50051 ^ + --api-key YOUR_KEY ^ + --operator-id YOUR_UUID ^ + --game "Roblox" ^ + --universe-id 123456789 ^ + --place-id 987654321 ^ + --roblox-account YourRobloxUsername +``` + +### Disable HUD overlay (headless / RDP sessions) +``` +python -m nightshade_client start --mock --no-hud --server localhost:50051 ... +``` + +--- + +## 9. Configuration via .env + +Create a `.env` file in the client directory to avoid passing flags every time: + +``` +NIGHTSHADE_SERVER_HOST=nightshade-server +NIGHTSHADE_SERVER_PORT=50051 +NIGHTSHADE_API_KEY=your-api-key-here +NIGHTSHADE_OPERATOR_ID=your-uuid-here +NIGHTSHADE_WHISPER_MODEL=base +NIGHTSHADE_CAPTURE_INTERVAL_MS=500 +NIGHTSHADE_OCR_REGION=0,800,1920,280 +NIGHTSHADE_AUDIO_DEVICE=Stereo Mix +NIGHTSHADE_LOG_LEVEL=INFO +``` + +`OCR_REGION` format is `x,y,w,h` — set this to cover just the chat area of the game +window to reduce OCR noise and improve performance. + +--- + +## 10. Troubleshooting + +### CUDA out of memory +Reduce the Whisper model size: +``` +NIGHTSHADE_WHISPER_MODEL=tiny +``` +Or close other GPU-heavy applications before running. + +### cuDNN not found (error loading libcudnn) +``` +# Ensure cuDNN DLLs are in PATH: +set PATH=C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.x\bin;%PATH% + +# Verify: +where cudnn64_8.dll +``` + +### Audio device not found +List available devices and find the correct name: +``` +python -c "import sounddevice; print(sounddevice.query_devices())" +``` +Set the device name in `.env`: `NIGHTSHADE_AUDIO_DEVICE=Stereo Mix` + +For game audio capture, use "Stereo Mix" (loopback) or a virtual audio cable like VB-Audio Virtual Cable. + +### PaddleOCR returns empty results +Check that the OCR region covers the chat area. Use None first (full screen) to verify +the model is working, then narrow the region. Also confirm `use_gpu=True` is not silently +falling back — check the PaddleOCR log output for "use_gpu" confirmation. + +### NeMo diarization fails with temp file errors +Ensure the user has write access to `%TEMP%`. NeMo writes intermediate WAV files there. +If running in a restricted environment, set `TMPDIR` to a writable path: +``` +set TMPDIR=C:\Users\YourUser\AppData\Local\Temp +``` + +### DearPyGui HUD not visible +DearPyGui requires a display. If running over RDP, the HUD may not render. Use `--no-hud` +flag in that case: +``` +python -m nightshade_client start --no-hud ... +``` + +### gRPC connection refused +Verify the server is running and reachable: +``` +python -c " +import grpc +channel = grpc.insecure_channel('nightshade-server:50051') +grpc.channel_ready_future(channel).result(timeout=5) +print('gRPC channel OK') +" +``` + +### ImportError for any dependency +Verify the virtual environment is active: +``` +.venv\Scripts\activate +python -c "import nightshade_client; print('SDK OK')" +``` + +--- + +## 11. Running Tests + +``` +# Activate venv +.venv\Scripts\activate + +# Run all tests (Part A — mocked, no GPU needed) +python -m pytest tests/ -v + +# Run with coverage +python -m pytest tests/ -v --cov=nightshade_client --cov-report=term-missing +``` + +Tests live in `tests/test_connection/`, `tests/test_session/`, `tests/test_streams/`, +and `tests/test_pipeline/`. All tests use mocks and do not require GPU or a live server. diff --git a/client/generate_proto.sh b/client/generate_proto.sh new file mode 100755 index 0000000..39a1f9a --- /dev/null +++ b/client/generate_proto.sh @@ -0,0 +1,23 @@ +#!/usr/bin/env bash +# Generate Python gRPC stubs from nightshade.proto +# Output goes into nightshade_client/grpc_stubs/ so they are importable as: +# from nightshade_client.grpc_stubs import nightshade_pb2, nightshade_pb2_grpc + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +cd "$SCRIPT_DIR" + +python -m grpc_tools.protoc \ + -I proto \ + --python_out=nightshade_client/grpc_stubs \ + --pyi_out=nightshade_client/grpc_stubs \ + --grpc_python_out=nightshade_client/grpc_stubs \ + proto/nightshade.proto + +# Fix import: protoc emits bare `import nightshade_pb2` but stubs live inside +# nightshade_client.grpc_stubs package, so patch to package-relative. +sed -i 's/^import nightshade_pb2 as nightshade__pb2$/from nightshade_client.grpc_stubs import nightshade_pb2 as nightshade__pb2/' \ + nightshade_client/grpc_stubs/nightshade_pb2_grpc.py + +echo "Proto stubs generated in nightshade_client/grpc_stubs/" diff --git a/client/nightshade_client/__init__.py b/client/nightshade_client/__init__.py new file mode 100644 index 0000000..562472a --- /dev/null +++ b/client/nightshade_client/__init__.py @@ -0,0 +1,3 @@ +"""Nightshade monitoring client SDK.""" + +__version__ = "0.1.0" diff --git a/client/nightshade_client/__main__.py b/client/nightshade_client/__main__.py new file mode 100644 index 0000000..ecbd7fb --- /dev/null +++ b/client/nightshade_client/__main__.py @@ -0,0 +1,109 @@ +""" +Nightshade client CLI entry point. + +Usage: + python -m nightshade_client start --server HOST:PORT --api-key KEY --operator-id UUID --game "Name" + python -m nightshade_client start --mock --no-hud --server localhost:50051 + python -m nightshade_client --help +""" + +import argparse +import asyncio +import logging +import sys + + +def parse_args(argv: list[str] | None = None) -> argparse.Namespace: + parser = argparse.ArgumentParser( + prog="nightshade_client", + description="Nightshade monitoring client", + ) + sub = parser.add_subparsers(dest="command", required=True) + + start = sub.add_parser("start", help="Start monitoring pipeline") + start.add_argument("--server", default="localhost:50051", + help="Server host:port (default: localhost:50051)") + start.add_argument("--api-key", default="", help="API key for authentication") + start.add_argument("--operator-id", default="", help="Operator UUID") + start.add_argument("--game", default="Roblox", help="Game name") + start.add_argument("--universe-id", default="", help="Roblox universe ID") + start.add_argument("--place-id", default="", help="Roblox place ID") + start.add_argument("--server-id", default="", help="Roblox server ID") + start.add_argument("--roblox-account", default="", help="Roblox account name") + start.add_argument("--roblox-account-id", default="", help="Roblox account ID") + start.add_argument("--mock", action="store_true", help="Use mock backends (no GPU/capture needed)") + start.add_argument("--no-hud", action="store_true", help="Disable HUD overlay") + start.add_argument("--config", default=None, help="Path to .env config file") + start.add_argument("--log-level", default=None, help="Log level override") + + return parser.parse_args(argv) + + +def main(argv: list[str] | None = None) -> int: + args = parse_args(argv) + + if args.command == "start": + return _cmd_start(args) + + return 1 + + +def _cmd_start(args: argparse.Namespace) -> int: + # Import here to avoid slow imports on --help + from nightshade_client.config import ClientSettings + from nightshade_client.pipeline.base import MockHUDOverlay + from nightshade_client.pipeline.orchestrator import PipelineOrchestrator + + # Build settings from env/config file, then override with CLI args + env_file = args.config or ".env" + try: + settings = ClientSettings(_env_file=env_file) + except Exception: + settings = ClientSettings() + + # CLI overrides + if args.server: + host, _, port = args.server.partition(":") + if host: + settings.server_host = host + if port: + settings.server_port = int(port) + if args.api_key: + settings.api_key = args.api_key + if args.operator_id: + settings.operator_id = args.operator_id + if args.log_level: + settings.log_level = args.log_level + + # Configure logging + logging.basicConfig( + level=getattr(logging, settings.log_level.upper(), logging.INFO), + format="%(asctime)s %(levelname)-8s %(name)s | %(message)s", + datefmt="%H:%M:%S", + ) + + hud = MockHUDOverlay() if args.no_hud else None + + orchestrator = PipelineOrchestrator( + settings=settings, + hud=hud, + use_mocks=args.mock, + ) + + try: + asyncio.run(orchestrator.run( + game_name=args.game, + universe_id=args.universe_id, + place_id=args.place_id, + server_id=args.server_id, + roblox_account=args.roblox_account, + roblox_account_id=args.roblox_account_id, + )) + except KeyboardInterrupt: + pass + + return 0 + + +if __name__ == "__main__": + sys.exit(main()) diff --git a/client/nightshade_client/capture/__init__.py b/client/nightshade_client/capture/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/client/nightshade_client/capture/audio.py b/client/nightshade_client/capture/audio.py new file mode 100644 index 0000000..c40e940 --- /dev/null +++ b/client/nightshade_client/capture/audio.py @@ -0,0 +1,12 @@ +"""Audio capture — Part B stub. + +Implement RealAudioCapture using sounddevice on Windows. +""" + +from nightshade_client.pipeline.base import ( # noqa: F401 + AudioCapture, + CapturedAudio, + MockAudioCapture, +) + +__all__ = ["AudioCapture", "CapturedAudio", "MockAudioCapture"] diff --git a/client/nightshade_client/capture/screen.py b/client/nightshade_client/capture/screen.py new file mode 100644 index 0000000..f5412ae --- /dev/null +++ b/client/nightshade_client/capture/screen.py @@ -0,0 +1,12 @@ +"""Screen capture — Part B stub. + +Implement RealScreenCapture using mss or similar on Windows. +""" + +from nightshade_client.pipeline.base import ( # noqa: F401 + CapturedFrame, + MockScreenCapture, + ScreenCapture, +) + +__all__ = ["ScreenCapture", "CapturedFrame", "MockScreenCapture"] diff --git a/client/nightshade_client/capture/video.py b/client/nightshade_client/capture/video.py new file mode 100644 index 0000000..d5b01af --- /dev/null +++ b/client/nightshade_client/capture/video.py @@ -0,0 +1,12 @@ +"""Video recording — Part B stub. + +Implement NVENCVideoRecorder using NVIDIA NVENC on Windows. +""" + +from nightshade_client.pipeline.base import ( # noqa: F401 + CapturedFrame, + MockVideoRecorder, + VideoRecorder, +) + +__all__ = ["VideoRecorder", "CapturedFrame", "MockVideoRecorder"] diff --git a/client/nightshade_client/config.py b/client/nightshade_client/config.py new file mode 100644 index 0000000..f78ae61 --- /dev/null +++ b/client/nightshade_client/config.py @@ -0,0 +1,61 @@ +""" +Nightshade client configuration via pydantic-settings. +Values are loaded from environment variables or a .env file. +""" + +from functools import lru_cache +from typing import Optional + +from pydantic import Field +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class ClientSettings(BaseSettings): + model_config = SettingsConfigDict( + env_prefix="NIGHTSHADE_", + env_file=".env", + env_file_encoding="utf-8", + extra="ignore", + ) + + # ── Server connection ────────────────────────────────────────────────── + server_host: str = "localhost" + server_port: int = 50051 + api_key: str = "" + use_tls: bool = False + + # ── Operator ─────────────────────────────────────────────────────────── + operator_id: str = "" + + # ── Capture settings ─────────────────────────────────────────────────── + capture_interval_ms: int = 500 + ocr_region: Optional[tuple[int, int, int, int]] = None # x, y, w, h + audio_device: Optional[str] = None + whisper_model: str = "base" + + # ── Buffering ────────────────────────────────────────────────────────── + buffer_max_messages: int = 1000 + heartbeat_interval_s: int = 10 + evidence_chunk_size: int = Field(default=1_048_576, description="1 MB") + + # ── Dashcam ──────────────────────────────────────────────────────────── + dashcam_buffer_seconds: int = 300 + + # ── Streaming ────────────────────────────────────────────────────────── + chat_batch_size: int = 50 + chat_batch_timeout_s: float = 2.0 + voice_batch_size: int = 50 + voice_batch_timeout_s: float = 2.0 + + # ── Reconnection ────────────────────────────────────────────────────── + reconnect_min_s: float = 1.0 + reconnect_max_s: float = 60.0 + + # ── Logging ──────────────────────────────────────────────────────────── + log_level: str = "INFO" + + +@lru_cache +def get_settings() -> ClientSettings: + """Return cached ClientSettings instance (reads .env once).""" + return ClientSettings() diff --git a/client/nightshade_client/connection/__init__.py b/client/nightshade_client/connection/__init__.py new file mode 100644 index 0000000..cf2b8b0 --- /dev/null +++ b/client/nightshade_client/connection/__init__.py @@ -0,0 +1,9 @@ +from nightshade_client.connection.auth import ApiKeyMetadata +from nightshade_client.connection.channel import ChannelManager +from nightshade_client.connection.health import HealthChecker + +__all__ = [ + "ApiKeyMetadata", + "ChannelManager", + "HealthChecker", +] diff --git a/client/nightshade_client/connection/auth.py b/client/nightshade_client/connection/auth.py new file mode 100644 index 0000000..eed3f47 --- /dev/null +++ b/client/nightshade_client/connection/auth.py @@ -0,0 +1,22 @@ +""" +API key metadata injection for gRPC calls. +""" + +import logging + +logger = logging.getLogger(__name__) + + +class ApiKeyMetadata: + """Injects x-api-key metadata into gRPC calls.""" + + def __init__(self, api_key: str) -> None: + self._api_key = api_key + + def get_metadata(self) -> list[tuple[str, str]]: + """Return metadata list for gRPC call: [("x-api-key", key)].""" + return [("x-api-key", self._api_key)] + + def __call__(self) -> list[tuple[str, str]]: + """Allow using as callable for convenience.""" + return self.get_metadata() diff --git a/client/nightshade_client/connection/channel.py b/client/nightshade_client/connection/channel.py new file mode 100644 index 0000000..afc7ad2 --- /dev/null +++ b/client/nightshade_client/connection/channel.py @@ -0,0 +1,109 @@ +""" +gRPC channel factory with reconnection logic. +""" + +import asyncio +import logging +import random +from typing import Optional + +import grpc +import grpc.aio + +from nightshade_client.config import ClientSettings +from nightshade_client.grpc_stubs import nightshade_pb2_grpc + +logger = logging.getLogger(__name__) + + +class ChannelManager: + """gRPC channel factory with reconnection logic.""" + + _CHANNEL_OPTIONS = [ + ("grpc.keepalive_time_ms", 10_000), + ("grpc.keepalive_timeout_ms", 5_000), + ("grpc.keepalive_permit_without_calls", 1), + ] + + def __init__(self, settings: ClientSettings) -> None: + self._settings = settings + self._channel: Optional[grpc.aio.Channel] = None + self._closed: bool = False + + async def connect(self) -> None: + """Create channel (insecure or TLS based on settings.use_tls).""" + target = f"{self._settings.server_host}:{self._settings.server_port}" + + if self._settings.use_tls: + credentials = grpc.ssl_channel_credentials() + self._channel = grpc.aio.secure_channel( + target, + credentials, + options=self._CHANNEL_OPTIONS, + ) + logger.info("Connected to %s (TLS)", target) + else: + self._channel = grpc.aio.insecure_channel( + target, + options=self._CHANNEL_OPTIONS, + ) + logger.info("Connected to %s (insecure)", target) + + self._closed = False + + def get_stub(self) -> nightshade_pb2_grpc.NightshadeServiceStub: + """Return a stub bound to the current channel.""" + if self._channel is None: + raise RuntimeError("Channel is not connected. Call connect() first.") + return nightshade_pb2_grpc.NightshadeServiceStub(self._channel) + + async def reconnect(self) -> None: + """Reconnect with exponential backoff (min_s -> max_s with jitter).""" + await self.close() + + delay = self._settings.reconnect_min_s + attempt = 0 + + while True: + attempt += 1 + jitter = random.uniform(0.0, 1.0) + wait = min(delay, self._settings.reconnect_max_s) + jitter + logger.info( + "Reconnect attempt %d: waiting %.2fs before connecting...", + attempt, + wait, + ) + await asyncio.sleep(wait) + + try: + await self.connect() + logger.info("Reconnected successfully on attempt %d.", attempt) + return + except Exception as exc: + logger.warning("Reconnect attempt %d failed: %s", attempt, exc) + delay = min(delay * 2, self._settings.reconnect_max_s) + + async def close(self) -> None: + """Close the channel.""" + if self._channel is not None: + await self._channel.close() + logger.debug("Channel closed.") + self._channel = None + self._closed = True + + @property + def is_connected(self) -> bool: + """Whether channel exists and is not closed.""" + return self._channel is not None and not self._closed + + async def __aenter__(self) -> "ChannelManager": + await self.connect() + return self + + async def __aexit__( + self, + exc_type: Optional[type], + exc_val: Optional[BaseException], + exc_tb: Optional[object], + ) -> None: + await self.close() diff --git a/client/nightshade_client/connection/health.py b/client/nightshade_client/connection/health.py new file mode 100644 index 0000000..fd2f0cb --- /dev/null +++ b/client/nightshade_client/connection/health.py @@ -0,0 +1,122 @@ +""" +Periodic connectivity health checker with event callbacks. +""" + +import asyncio +import logging +from typing import Callable, Optional + +import grpc + +from nightshade_client.connection.channel import ChannelManager + +logger = logging.getLogger(__name__) + +# Channel states that indicate an active, usable connection. +_READY_STATES = frozenset([grpc.ChannelConnectivity.READY]) + +# Channel states that indicate the connection is not available. +_NOT_READY_STATES = frozenset([ + grpc.ChannelConnectivity.IDLE, + grpc.ChannelConnectivity.CONNECTING, + grpc.ChannelConnectivity.TRANSIENT_FAILURE, + grpc.ChannelConnectivity.SHUTDOWN, +]) + + +class HealthChecker: + """Periodic connectivity check with event callbacks.""" + + def __init__( + self, + channel_manager: ChannelManager, + check_interval_s: float = 15.0, + ) -> None: + self._channel = channel_manager + self._interval = check_interval_s + self._connected: bool = False + self._on_connect: Optional[Callable] = None + self._on_disconnect: Optional[Callable] = None + self._task: Optional[asyncio.Task] = None + + def on_connect(self, callback: Callable) -> None: + """Register callback for connection established.""" + self._on_connect = callback + + def on_disconnect(self, callback: Callable) -> None: + """Register callback for connection lost.""" + self._on_disconnect = callback + + async def start(self) -> None: + """Start background health check loop.""" + if self._task is not None and not self._task.done(): + logger.warning("HealthChecker already running; ignoring start().") + return + self._task = asyncio.create_task(self._check_loop()) + logger.debug("HealthChecker started (interval=%.1fs).", self._interval) + + async def stop(self) -> None: + """Cancel the health check task.""" + if self._task is not None and not self._task.done(): + self._task.cancel() + try: + await self._task + except asyncio.CancelledError: + pass + logger.debug("HealthChecker stopped.") + self._task = None + + async def _check_loop(self) -> None: + """Periodic loop: check channel state, emit events on transitions.""" + while True: + try: + now_connected = self._is_channel_ready() + + if now_connected and not self._connected: + self._connected = True + logger.info("Channel is READY — firing on_connect callback.") + if self._on_connect is not None: + try: + result = self._on_connect() + if asyncio.iscoroutine(result): + await result + except Exception as exc: + logger.exception("on_connect callback raised: %s", exc) + + elif not now_connected and self._connected: + self._connected = False + logger.warning("Channel is NOT READY — firing on_disconnect callback.") + if self._on_disconnect is not None: + try: + result = self._on_disconnect() + if asyncio.iscoroutine(result): + await result + except Exception as exc: + logger.exception("on_disconnect callback raised: %s", exc) + + except Exception as exc: + logger.exception("Unexpected error in health check loop: %s", exc) + + await asyncio.sleep(self._interval) + + def _is_channel_ready(self) -> bool: + """Return True if the underlying channel reports READY state.""" + if not self._channel.is_connected: + return False + + # Access the raw grpc.aio.Channel to read connectivity state. + raw_channel = self._channel._channel # type: ignore[attr-defined] + if raw_channel is None: + return False + + try: + state = raw_channel.get_state(try_to_connect=False) + return state in _READY_STATES + except Exception as exc: + logger.debug("Could not read channel state: %s", exc) + return False + + @property + def is_healthy(self) -> bool: + """Whether the channel was last observed as READY.""" + return self._connected diff --git a/client/nightshade_client/grpc_stubs/__init__.py b/client/nightshade_client/grpc_stubs/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/client/nightshade_client/hud/__init__.py b/client/nightshade_client/hud/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/client/nightshade_client/hud/overlay.py b/client/nightshade_client/hud/overlay.py new file mode 100644 index 0000000..d67af2d --- /dev/null +++ b/client/nightshade_client/hud/overlay.py @@ -0,0 +1,8 @@ +"""HUD overlay — Part B stub. + +Implement DearPyGuiOverlay using DearPyGui on Windows. +""" + +from nightshade_client.pipeline.base import HUDOverlay, MockHUDOverlay # noqa: F401 + +__all__ = ["HUDOverlay", "MockHUDOverlay"] diff --git a/client/nightshade_client/pipeline/__init__.py b/client/nightshade_client/pipeline/__init__.py new file mode 100644 index 0000000..f87fadd --- /dev/null +++ b/client/nightshade_client/pipeline/__init__.py @@ -0,0 +1,2 @@ +from nightshade_client.pipeline.base import * # noqa: F401,F403 +from nightshade_client.pipeline.buffer import RingBuffer # noqa: F401 diff --git a/client/nightshade_client/pipeline/base.py b/client/nightshade_client/pipeline/base.py new file mode 100644 index 0000000..5df9155 --- /dev/null +++ b/client/nightshade_client/pipeline/base.py @@ -0,0 +1,373 @@ +""" +Pipeline ABCs, dataclasses, and mock implementations. + +These define the interface contract between Part A (platform-independent SDK) +and Part B (Windows capture/processing). Windows Claude Code implements the +ABCs; Part A calls them via the orchestrator. +""" + +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from pathlib import Path +from typing import Optional +import random +import string +import time + +import numpy as np + + +# ─── Data Classes ──────────────────────────────────────────────────────────── + + +@dataclass +class CapturedFrame: + """A single screen capture frame.""" + + image: np.ndarray # HxWxC uint8 BGR + timestamp: float # unix seconds + width: int = 0 + height: int = 0 + + def __post_init__(self): + if self.width == 0 and self.image is not None: + self.height, self.width = self.image.shape[:2] + + +@dataclass +class CapturedAudio: + """A chunk of captured audio.""" + + samples: np.ndarray # float32 mono PCM + sample_rate: int = 16000 + timestamp: float = 0.0 + duration_s: float = 0.0 + + +@dataclass +class OCRResult: + """OCR output from a single frame.""" + + texts: list[str] = field(default_factory=list) + confidences: list[float] = field(default_factory=list) + boxes: list[tuple[int, int, int, int]] = field(default_factory=list) # x,y,w,h + timestamp: float = 0.0 + frame_width: int = 0 + frame_height: int = 0 + + +@dataclass +class TranscriptSegment: + """A single speech-to-text segment.""" + + text: str = "" + speaker_label: str = "" + confidence: float = 0.0 + diarization_confidence: float = 0.0 + start_time: float = 0.0 + end_time: float = 0.0 + timestamp: float = 0.0 + + +# ─── Abstract Base Classes ─────────────────────────────────────────────────── + + +class ScreenCapture(ABC): + """Captures screen frames.""" + + @abstractmethod + async def start(self) -> None: + """Initialize capture resources.""" + + @abstractmethod + async def capture_frame(self, region: Optional[tuple[int, int, int, int]] = None) -> CapturedFrame: + """Capture a single frame, optionally from a specific region (x,y,w,h).""" + + @abstractmethod + async def stop(self) -> None: + """Release capture resources.""" + + +class AudioCapture(ABC): + """Captures audio from a device.""" + + @abstractmethod + async def start(self, device: Optional[str] = None, sample_rate: int = 16000) -> None: + """Start audio capture from the specified device.""" + + @abstractmethod + async def read_chunk(self, duration_s: float = 5.0) -> CapturedAudio: + """Read a chunk of audio of the given duration.""" + + @abstractmethod + async def stop(self) -> None: + """Stop audio capture and release resources.""" + + +class VideoRecorder(ABC): + """Records video (e.g., NVENC-accelerated).""" + + @abstractmethod + async def start(self, output_path: Path, fps: int = 30) -> None: + """Start recording to the given path.""" + + @abstractmethod + async def write_frame(self, frame: CapturedFrame) -> None: + """Write a frame to the recording.""" + + @abstractmethod + async def stop(self) -> Path: + """Stop recording, finalize file, return path.""" + + +class OCREngine(ABC): + """Runs OCR on captured frames.""" + + @abstractmethod + async def initialize(self, model: str = "default") -> None: + """Load OCR model.""" + + @abstractmethod + async def extract_text(self, frame: CapturedFrame) -> OCRResult: + """Extract text from a frame.""" + + @abstractmethod + async def shutdown(self) -> None: + """Unload model and free resources.""" + + +class SpeechToText(ABC): + """Transcribes audio to text.""" + + @abstractmethod + async def initialize(self, model: str = "base") -> None: + """Load STT model.""" + + @abstractmethod + async def transcribe(self, audio: CapturedAudio) -> list[TranscriptSegment]: + """Transcribe an audio chunk into segments.""" + + @abstractmethod + async def shutdown(self) -> None: + """Unload model and free resources.""" + + +class SpeakerDiarizer(ABC): + """Identifies speaker segments in audio.""" + + @abstractmethod + async def initialize(self) -> None: + """Load diarization model.""" + + @abstractmethod + async def diarize(self, audio: CapturedAudio) -> list[TranscriptSegment]: + """Assign speaker labels to audio segments.""" + + @abstractmethod + async def shutdown(self) -> None: + """Unload model and free resources.""" + + +class HUDOverlay(ABC): + """Heads-up display overlay for operator feedback.""" + + @abstractmethod + async def start(self) -> None: + """Initialize and show the HUD window.""" + + @abstractmethod + async def update_status(self, status: str) -> None: + """Update the main status text.""" + + @abstractmethod + async def show_alert(self, alert_type: str, detail: str) -> None: + """Display an alert notification.""" + + @abstractmethod + async def update_stats( + self, + messages: int = 0, + alerts: int = 0, + subjects: int = 0, + ) -> None: + """Update session statistics display.""" + + @abstractmethod + async def stop(self) -> None: + """Close the HUD window.""" + + +# ─── Mock Implementations ─────────────────────────────────────────────────── + + +_MOCK_USERNAMES = ["Player123", "xXGamerXx", "CoolKid99", "TestUser42", "RobloxFan"] +_MOCK_MESSAGES = [ + "hey whats up", + "anyone want to play", + "lol nice", + "gg", + "can you add me", + "where are you from", + "how old are you", + "this game is fun", + "brb", + "im bored lets go somewhere else", +] + + +class MockScreenCapture(ScreenCapture): + """Generates synthetic frames for testing.""" + + async def start(self) -> None: + pass + + async def capture_frame(self, region: Optional[tuple[int, int, int, int]] = None) -> CapturedFrame: + w, h = (region[2], region[3]) if region else (1920, 1080) + # Small random image to keep memory low + image = np.random.randint(0, 255, (h // 10, w // 10, 3), dtype=np.uint8) + return CapturedFrame(image=image, timestamp=time.time(), width=w, height=h) + + async def stop(self) -> None: + pass + + +class MockAudioCapture(AudioCapture): + """Generates synthetic audio for testing.""" + + async def start(self, device: Optional[str] = None, sample_rate: int = 16000) -> None: + self._sample_rate = sample_rate + + async def read_chunk(self, duration_s: float = 5.0) -> CapturedAudio: + n_samples = int(self._sample_rate * duration_s) + samples = np.random.randn(n_samples).astype(np.float32) * 0.01 + return CapturedAudio( + samples=samples, + sample_rate=self._sample_rate, + timestamp=time.time(), + duration_s=duration_s, + ) + + async def stop(self) -> None: + pass + + +class MockVideoRecorder(VideoRecorder): + """Simulates video recording to a temp file.""" + + def __init__(self): + self._path: Optional[Path] = None + self._frame_count = 0 + + async def start(self, output_path: Path, fps: int = 30) -> None: + self._path = output_path + self._frame_count = 0 + + async def write_frame(self, frame: CapturedFrame) -> None: + self._frame_count += 1 + + async def stop(self) -> Path: + path = self._path or Path("/tmp/mock_recording.mp4") + # Write a small marker file so tests can verify + path.parent.mkdir(parents=True, exist_ok=True) + path.write_bytes(b"MOCK_VIDEO:" + str(self._frame_count).encode()) + return path + + +class MockOCREngine(OCREngine): + """Returns synthetic chat messages as OCR results.""" + + async def initialize(self, model: str = "default") -> None: + pass + + async def extract_text(self, frame: CapturedFrame) -> OCRResult: + # Simulate 1-3 chat lines per frame + n = random.randint(1, 3) + texts = [] + for _ in range(n): + user = random.choice(_MOCK_USERNAMES) + msg = random.choice(_MOCK_MESSAGES) + texts.append(f"[{user}]: {msg}") + + return OCRResult( + texts=texts, + confidences=[random.uniform(0.7, 0.99) for _ in texts], + boxes=[(100, 100 + i * 30, 400, 20) for i in range(len(texts))], + timestamp=frame.timestamp, + frame_width=frame.width, + frame_height=frame.height, + ) + + async def shutdown(self) -> None: + pass + + +class MockSpeechToText(SpeechToText): + """Returns synthetic transcript segments.""" + + async def initialize(self, model: str = "base") -> None: + pass + + async def transcribe(self, audio: CapturedAudio) -> list[TranscriptSegment]: + # Simulate 0-2 segments per chunk + n = random.randint(0, 2) + segments = [] + for i in range(n): + segments.append( + TranscriptSegment( + text=random.choice(_MOCK_MESSAGES), + speaker_label=f"speaker_{random.randint(0, 3)}", + confidence=random.uniform(0.6, 0.99), + diarization_confidence=random.uniform(0.5, 0.95), + start_time=audio.timestamp + i * 2.0, + end_time=audio.timestamp + (i + 1) * 2.0, + timestamp=audio.timestamp, + ) + ) + return segments + + async def shutdown(self) -> None: + pass + + +class MockSpeakerDiarizer(SpeakerDiarizer): + """Returns synthetic speaker assignments.""" + + async def initialize(self) -> None: + pass + + async def diarize(self, audio: CapturedAudio) -> list[TranscriptSegment]: + n = random.randint(1, 3) + segments = [] + for i in range(n): + segments.append( + TranscriptSegment( + speaker_label=f"speaker_{random.randint(0, 3)}", + start_time=audio.timestamp + i * 1.5, + end_time=audio.timestamp + (i + 1) * 1.5, + timestamp=audio.timestamp, + diarization_confidence=random.uniform(0.6, 0.95), + ) + ) + return segments + + async def shutdown(self) -> None: + pass + + +class MockHUDOverlay(HUDOverlay): + """No-op HUD for headless/test mode.""" + + async def start(self) -> None: + pass + + async def update_status(self, status: str) -> None: + pass + + async def show_alert(self, alert_type: str, detail: str) -> None: + pass + + async def update_stats(self, messages: int = 0, alerts: int = 0, subjects: int = 0) -> None: + pass + + async def stop(self) -> None: + pass diff --git a/client/nightshade_client/pipeline/buffer.py b/client/nightshade_client/pipeline/buffer.py new file mode 100644 index 0000000..9eb362a --- /dev/null +++ b/client/nightshade_client/pipeline/buffer.py @@ -0,0 +1,139 @@ +""" +RingBuffer: fixed-size circular buffer for video frames (dashcam mode). + +Frames are appended continuously at a configurable FPS. When a trigger +event occurs, lock_and_save() can extract the pre-event window and write +a marker file (Part A stub — real video encoding comes in Part B via +VideoRecorder). +""" + +import asyncio +import collections +import json +import logging +import time +from pathlib import Path + +from nightshade_client.pipeline.base import CapturedFrame + +logger = logging.getLogger(__name__) + + +class RingBuffer: + """Fixed-size circular buffer for CapturedFrame objects. + + The buffer holds at most ``max_seconds * fps`` frames, discarding the + oldest frame automatically when full (deque maxlen behaviour). + + Thread safety: all public methods are async and protected by an + asyncio.Lock so they are safe to call from multiple coroutines. + """ + + def __init__(self, max_seconds: int = 300, fps: int = 2) -> None: + self._max_frames: int = max_seconds * fps + self._fps: int = fps + self._buffer: collections.deque[CapturedFrame] = collections.deque( + maxlen=self._max_frames + ) + self._lock: asyncio.Lock = asyncio.Lock() + logger.debug( + "RingBuffer initialised: max_seconds=%d, fps=%d, max_frames=%d", + max_seconds, + fps, + self._max_frames, + ) + + # ── Public API ────────────────────────────────────────────────────────── + + async def append(self, frame: CapturedFrame) -> None: + """Add a frame to the ring buffer, evicting the oldest if full.""" + async with self._lock: + self._buffer.append(frame) + + async def lock_and_save( + self, + pre_seconds: float, + post_seconds: float, + output_path: Path, + fps: int = 2, + ) -> Path: + """Freeze the buffer, extract the pre-event window, and save a marker. + + Part A implementation: takes a snapshot of the most-recent + ``pre_seconds * fps`` frames and writes their timestamps plus + metadata to a JSON marker file. The real video encoding (NVENC / + OpenCV) is deferred to Part B where a VideoRecorder implementation + is available. + + After saving the pre-event snapshot the lock is released so the + live capture pipeline can continue appending frames for the + post_seconds window. Collecting the post-event frames is also a + Part B responsibility because it requires re-acquiring the buffer + after a timed delay. + + Args: + pre_seconds: How many seconds of footage before the trigger. + post_seconds: How many seconds of footage after the trigger + (recorded in metadata; collection is Part B). + output_path: Destination path for the marker / video file. + fps: Frame rate used when converting seconds to frame + counts. Defaults to the buffer's configured FPS. + + Returns: + The output_path that was written. + """ + pre_frame_count = int(pre_seconds * fps) + + async with self._lock: + # Snapshot the tail of the deque (most-recent frames). + snapshot = list(self._buffer) + + # Take only the last pre_frame_count frames. + pre_frames = snapshot[-pre_frame_count:] if pre_frame_count > 0 else [] + + output_path.parent.mkdir(parents=True, exist_ok=True) + + metadata = { + "trigger_time": time.time(), + "pre_seconds": pre_seconds, + "post_seconds": post_seconds, + "fps": fps, + "pre_frame_count_requested": pre_frame_count, + "pre_frame_count_captured": len(pre_frames), + "total_frames_in_buffer": len(snapshot), + "frame_timestamps": [f.timestamp for f in pre_frames], + "note": ( + "Part A stub — real video encoding via VideoRecorder " + "is implemented in Part B." + ), + } + + marker_path = output_path.with_suffix(".json") + marker_path.write_text(json.dumps(metadata, indent=2)) + + logger.info( + "RingBuffer: saved pre-event marker — %d frames (%.1fs), path=%s", + len(pre_frames), + pre_seconds, + marker_path, + ) + + return output_path + + # ── Properties / utilities ────────────────────────────────────────────── + + @property + def frame_count(self) -> int: + """Current number of frames held in the buffer.""" + return len(self._buffer) + + @property + def max_frames(self) -> int: + """Maximum capacity of the buffer in frames.""" + return self._max_frames + + def clear(self) -> None: + """Clear all frames from the buffer (not async — no lock needed for a + single atomic deque.clear() call in CPython).""" + self._buffer.clear() + logger.debug("RingBuffer: cleared.") diff --git a/client/nightshade_client/pipeline/orchestrator.py b/client/nightshade_client/pipeline/orchestrator.py new file mode 100644 index 0000000..41805b5 --- /dev/null +++ b/client/nightshade_client/pipeline/orchestrator.py @@ -0,0 +1,365 @@ +""" +Pipeline orchestrator — wires capture sources, processors, and streams together. + +The main async loop that runs the entire client pipeline. +""" + +import asyncio +import logging +import signal +from pathlib import Path +from typing import Optional + +from nightshade_client.config import ClientSettings +from nightshade_client.connection import ApiKeyMetadata, ChannelManager, HealthChecker +from nightshade_client.pipeline.base import ( + AudioCapture, + CapturedAudio, + CapturedFrame, + HUDOverlay, + MockAudioCapture, + MockHUDOverlay, + MockOCREngine, + MockScreenCapture, + MockSpeakerDiarizer, + MockSpeechToText, + MockVideoRecorder, + OCREngine, + ScreenCapture, + SpeakerDiarizer, + SpeechToText, + VideoRecorder, +) +from nightshade_client.pipeline.buffer import RingBuffer +from nightshade_client.session import SessionManager, SessionMonitor, SessionState +from nightshade_client.streams import ( + ChatStreamManager, + EvidenceStreamManager, + VoiceStreamManager, +) + +logger = logging.getLogger(__name__) + + +class PipelineOrchestrator: + """Wires capture → processing → streaming pipeline.""" + + def __init__( + self, + settings: ClientSettings, + *, + screen: Optional[ScreenCapture] = None, + audio: Optional[AudioCapture] = None, + video: Optional[VideoRecorder] = None, + ocr: Optional[OCREngine] = None, + stt: Optional[SpeechToText] = None, + diarizer: Optional[SpeakerDiarizer] = None, + hud: Optional[HUDOverlay] = None, + use_mocks: bool = False, + ): + self._settings = settings + + # Use mocks if requested or if no real impl provided + if use_mocks: + self._screen = screen or MockScreenCapture() + self._audio = audio or MockAudioCapture() + self._video = video or MockVideoRecorder() + self._ocr = ocr or MockOCREngine() + self._stt = stt or MockSpeechToText() + self._diarizer = diarizer or MockSpeakerDiarizer() + self._hud = hud or MockHUDOverlay() + else: + self._screen = screen or MockScreenCapture() + self._audio = audio or MockAudioCapture() + self._video = video or MockVideoRecorder() + self._ocr = ocr or MockOCREngine() + self._stt = stt or MockSpeechToText() + self._diarizer = diarizer or MockSpeakerDiarizer() + self._hud = hud or MockHUDOverlay() + + # Infrastructure (initialized in run()) + self._channel: Optional[ChannelManager] = None + self._auth: Optional[ApiKeyMetadata] = None + self._session_mgr: Optional[SessionManager] = None + self._monitor: Optional[SessionMonitor] = None + self._health: Optional[HealthChecker] = None + self._chat_stream: Optional[ChatStreamManager] = None + self._voice_stream: Optional[VoiceStreamManager] = None + self._evidence_stream: Optional[EvidenceStreamManager] = None + self._ring_buffer: Optional[RingBuffer] = None + + self._shutdown_event = asyncio.Event() + self._stats = {"messages": 0, "alerts": 0, "subjects": 0} + + async def run( + self, + game_name: str, + operator_id: Optional[str] = None, + universe_id: str = "", + place_id: str = "", + server_id: str = "", + roblox_account: str = "", + roblox_account_id: str = "", + ) -> None: + """Main pipeline entry point. + + Connect → start session → launch coroutines → wait for shutdown. + """ + op_id = operator_id or self._settings.operator_id + if not op_id: + raise ValueError("operator_id required (via arg or NIGHTSHADE_OPERATOR_ID)") + + logger.info("Pipeline starting | server=%s:%d game=%s", + self._settings.server_host, self._settings.server_port, game_name) + + # Setup signal handlers + loop = asyncio.get_running_loop() + for sig in (signal.SIGINT, signal.SIGTERM): + loop.add_signal_handler(sig, self._signal_shutdown) + + # Initialize infrastructure + self._auth = ApiKeyMetadata(self._settings.api_key) + self._channel = ChannelManager(self._settings) + await self._channel.connect() + + self._session_mgr = SessionManager(self._channel, self._auth, self._settings) + self._monitor = SessionMonitor(self._channel, self._auth, self._settings) + self._health = HealthChecker(self._channel) + self._chat_stream = ChatStreamManager(self._channel, self._auth, self._settings) + self._voice_stream = VoiceStreamManager(self._channel, self._auth, self._settings) + self._evidence_stream = EvidenceStreamManager(self._channel, self._auth, self._settings) + self._ring_buffer = RingBuffer( + max_seconds=self._settings.dashcam_buffer_seconds, + fps=max(1, 1000 // self._settings.capture_interval_ms), + ) + + # Register monitor callbacks + self._monitor.on_alert(self._handle_server_alert) + self._monitor.on_update(self._handle_server_update) + + # Initialize capture/processing + await self._screen.start() + await self._audio.start(device=self._settings.audio_device) + await self._ocr.initialize() + await self._stt.initialize(model=self._settings.whisper_model) + await self._diarizer.initialize() + await self._hud.start() + + try: + # Start session + session_id = await self._session_mgr.start_session( + operator_id=op_id, + game_name=game_name, + universe_id=universe_id, + place_id=place_id, + server_id=server_id, + roblox_account=roblox_account, + roblox_account_id=roblox_account_id, + ) + logger.info("Session started: %s", session_id) + + # Wire session_id into all stream managers + self._monitor.set_session_id(session_id) + self._chat_stream.set_session_id(session_id) + self._voice_stream.set_session_id(session_id) + self._evidence_stream.set_session_id(session_id) + + await self._hud.update_status(f"Session active: {session_id[:8]}...") + + # Launch all coroutines + async with asyncio.TaskGroup() as tg: + tg.create_task(self._health.start(), name="health") + tg.create_task(self._monitor.run(), name="monitor") + tg.create_task(self._chat_stream.run(), name="chat_stream") + tg.create_task(self._voice_stream.run(), name="voice_stream") + tg.create_task(self._capture_loop(), name="capture") + tg.create_task(self._audio_loop(), name="audio") + tg.create_task(self._shutdown_waiter(), name="shutdown") + + except* Exception as eg: + for exc in eg.exceptions: + if not isinstance(exc, asyncio.CancelledError): + logger.error("Pipeline error: %s", exc) + finally: + await self._cleanup(game_name) + + async def _capture_loop(self) -> None: + """Screen capture → OCR → chat stream loop.""" + interval = self._settings.capture_interval_ms / 1000.0 + while not self._shutdown_event.is_set(): + try: + frame = await self._screen.capture_frame(region=self._settings.ocr_region) + await self._ring_buffer.append(frame) + + result = await self._ocr.extract_text(frame) + for i, text in enumerate(result.texts): + # Parse "[username]: message" format from OCR + username, message = self._parse_chat_line(text) + if message: + conf = result.confidences[i] if i < len(result.confidences) else 0.0 + await self._chat_stream.enqueue( + username=username, + user_id="", + message=message, + source="game_chat", + confidence=conf, + ) + self._stats["messages"] += 1 + + # Update monitor with buffered count + if self._monitor: + self._monitor.update_buffered_count(self._chat_stream._queue.qsize()) + + await self._hud.update_stats( + messages=self._stats["messages"], + alerts=self._stats["alerts"], + subjects=self._stats["subjects"], + ) + + except asyncio.CancelledError: + raise + except Exception as exc: + logger.error("Capture loop error: %s", exc) + + await asyncio.sleep(interval) + + async def _audio_loop(self) -> None: + """Audio capture → STT → diarization → voice stream loop.""" + while not self._shutdown_event.is_set(): + try: + audio = await self._audio.read_chunk(duration_s=5.0) + + segments = await self._stt.transcribe(audio) + diarized = await self._diarizer.diarize(audio) + + # Merge diarization labels into transcript segments + for seg in segments: + # Find closest diarization segment by time + best_label = seg.speaker_label + best_conf = seg.diarization_confidence + for d in diarized: + if d.start_time <= seg.start_time <= d.end_time: + best_label = d.speaker_label + best_conf = d.diarization_confidence + break + + if seg.text.strip(): + await self._voice_stream.enqueue( + speaker_label=best_label, + user_id_guess="", + text=seg.text, + confidence=seg.confidence, + diarization_confidence=best_conf, + ) + + except asyncio.CancelledError: + raise + except Exception as exc: + logger.error("Audio loop error: %s", exc) + + async def _shutdown_waiter(self) -> None: + """Wait for shutdown signal, then cancel all tasks.""" + await self._shutdown_event.wait() + logger.info("Shutdown signal received, stopping pipeline...") + raise asyncio.CancelledError() + + def _signal_shutdown(self) -> None: + """Signal handler for SIGINT/SIGTERM.""" + logger.info("Signal received, initiating shutdown...") + self._shutdown_event.set() + + async def request_shutdown(self) -> None: + """Programmatic shutdown request.""" + self._shutdown_event.set() + + async def _handle_server_alert(self, alert) -> None: + """Handle AlertNotification from server.""" + logger.warning( + "Server alert: type=%s subject=%s risk=%.1f", + alert.alert_type, alert.subject_username, alert.risk_score, + ) + self._stats["alerts"] += 1 + await self._hud.show_alert(alert.alert_type, alert.trigger_detail) + + # Trigger dashcam save on alert + if self._ring_buffer: + try: + output = Path(f"/tmp/nightshade_dashcam_{alert.alert_id}.bin") + await self._ring_buffer.lock_and_save( + pre_seconds=30.0, + post_seconds=10.0, + output_path=output, + ) + logger.info("Dashcam saved: %s", output) + # Upload as evidence + if self._evidence_stream: + await self._evidence_stream.upload_file(output, file_type="recording") + except Exception as exc: + logger.error("Dashcam save failed: %s", exc) + + async def _handle_server_update(self, update) -> None: + """Handle SessionUpdate from server.""" + self._stats["subjects"] = update.active_subjects + await self._hud.update_stats( + messages=self._stats["messages"], + alerts=self._stats["alerts"], + subjects=self._stats["subjects"], + ) + + async def _cleanup(self, game_name: str) -> None: + """Graceful shutdown of all components.""" + logger.info("Cleaning up pipeline...") + + # Stop streams + for component in [self._chat_stream, self._voice_stream]: + if component: + try: + await component.stop() + except Exception as exc: + logger.error("Stream stop error: %s", exc) + + # Stop monitor and health + for component in [self._monitor, self._health]: + if component: + try: + await component.stop() + except Exception as exc: + logger.error("Component stop error: %s", exc) + + # End session + if self._session_mgr and self._session_mgr.state == SessionState.ACTIVE: + try: + stats = await self._session_mgr.end_session(reason="shutdown") + logger.info("Session ended: %s", stats) + except Exception as exc: + logger.error("End session error: %s", exc) + + # Cleanup capture/processing + for component in [self._screen, self._audio, self._ocr, self._stt, self._diarizer, self._hud]: + if component: + try: + stop = getattr(component, "stop", None) or getattr(component, "shutdown", None) + if stop: + await stop() + except Exception as exc: + logger.error("Cleanup error: %s", exc) + + # Close channel + if self._channel: + try: + await self._channel.close() + except Exception as exc: + logger.error("Channel close error: %s", exc) + + logger.info("Pipeline shutdown complete") + + @staticmethod + def _parse_chat_line(text: str) -> tuple[str, str]: + """Parse '[username]: message' format. Returns (username, message).""" + text = text.strip() + if text.startswith("[") and "]: " in text: + bracket_end = text.index("]: ") + username = text[1:bracket_end] + message = text[bracket_end + 3:] + return username, message + return "", text diff --git a/client/nightshade_client/processing/__init__.py b/client/nightshade_client/processing/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/client/nightshade_client/processing/diarize.py b/client/nightshade_client/processing/diarize.py new file mode 100644 index 0000000..e6158db --- /dev/null +++ b/client/nightshade_client/processing/diarize.py @@ -0,0 +1,13 @@ +"""Speaker diarization — Part B stub. + +Implement NeMoDiarizer using NeMo toolkit on Windows with GPU. +""" + +from nightshade_client.pipeline.base import ( # noqa: F401 + CapturedAudio, + MockSpeakerDiarizer, + SpeakerDiarizer, + TranscriptSegment, +) + +__all__ = ["SpeakerDiarizer", "TranscriptSegment", "CapturedAudio", "MockSpeakerDiarizer"] diff --git a/client/nightshade_client/processing/ocr.py b/client/nightshade_client/processing/ocr.py new file mode 100644 index 0000000..6415044 --- /dev/null +++ b/client/nightshade_client/processing/ocr.py @@ -0,0 +1,13 @@ +"""OCR processing — Part B stub. + +Implement PaddleOCREngine using PaddleOCR on Windows with GPU. +""" + +from nightshade_client.pipeline.base import ( # noqa: F401 + CapturedFrame, + MockOCREngine, + OCREngine, + OCRResult, +) + +__all__ = ["OCREngine", "OCRResult", "CapturedFrame", "MockOCREngine"] diff --git a/client/nightshade_client/processing/stt.py b/client/nightshade_client/processing/stt.py new file mode 100644 index 0000000..2b1d333 --- /dev/null +++ b/client/nightshade_client/processing/stt.py @@ -0,0 +1,13 @@ +"""Speech-to-text — Part B stub. + +Implement WhisperSTT using faster-whisper on Windows with GPU. +""" + +from nightshade_client.pipeline.base import ( # noqa: F401 + CapturedAudio, + MockSpeechToText, + SpeechToText, + TranscriptSegment, +) + +__all__ = ["SpeechToText", "TranscriptSegment", "CapturedAudio", "MockSpeechToText"] diff --git a/client/nightshade_client/session/__init__.py b/client/nightshade_client/session/__init__.py new file mode 100644 index 0000000..be2d5f9 --- /dev/null +++ b/client/nightshade_client/session/__init__.py @@ -0,0 +1,4 @@ +from nightshade_client.session.manager import SessionManager, SessionState +from nightshade_client.session.monitor import SessionMonitor + +__all__ = ["SessionManager", "SessionState", "SessionMonitor"] diff --git a/client/nightshade_client/session/manager.py b/client/nightshade_client/session/manager.py new file mode 100644 index 0000000..083ea4e --- /dev/null +++ b/client/nightshade_client/session/manager.py @@ -0,0 +1,226 @@ +""" +Session lifecycle manager with state machine and exponential backoff retry. +""" + +import asyncio +import enum +import logging +import random +from typing import Optional + +import grpc + +from nightshade_client.config import ClientSettings +from nightshade_client.connection.auth import ApiKeyMetadata +from nightshade_client.connection.channel import ChannelManager +from nightshade_client.grpc_stubs import nightshade_pb2 + +logger = logging.getLogger(__name__) + + +class SessionState(enum.Enum): + IDLE = "idle" + STARTING = "starting" + ACTIVE = "active" + ENDING = "ending" + ENDED = "ended" + + +class SessionManager: + """Manages session lifecycle with a strict state machine. + + Valid transitions: + IDLE -> STARTING -> ACTIVE (start_session success) + STARTING -> IDLE (start_session exhausted retries) + ACTIVE -> ENDING -> ENDED (end_session success) + Any -> IDLE (reset) + """ + + def __init__( + self, + channel_manager: ChannelManager, + auth: ApiKeyMetadata, + settings: ClientSettings, + ) -> None: + self._channel = channel_manager + self._auth = auth + self._settings = settings + self._state = SessionState.IDLE + self._session_id: Optional[str] = None + + # ── Properties ───────────────────────────────────────────────────────── + + @property + def state(self) -> SessionState: + """Current session state.""" + return self._state + + @property + def session_id(self) -> Optional[str]: + """Active session ID, or None if not yet started.""" + return self._session_id + + # ── Public API ───────────────────────────────────────────────────────── + + async def start_session( + self, + operator_id: str, + game_name: str, + universe_id: str = "", + place_id: str = "", + server_id: str = "", + roblox_account: str = "", + roblox_account_id: str = "", + ) -> str: + """Start a monitoring session and return the session_id. + + Transitions: IDLE -> STARTING -> ACTIVE + Retries with exponential backoff on transient gRPC errors. + Resets to IDLE and raises on exhausted retries. + """ + if self._state is not SessionState.IDLE: + raise RuntimeError( + f"Cannot start session from state {self._state.value!r}. " + "Call reset() first." + ) + + self._state = SessionState.STARTING + logger.info( + "Starting session for operator=%s game=%r", operator_id, game_name + ) + + request = nightshade_pb2.StartSessionRequest( + operator_id=operator_id, + game_name=game_name, + universe_id=universe_id, + place_id=place_id, + server_id=server_id, + roblox_account=roblox_account, + roblox_account_id=roblox_account_id, + ) + + delay = self._settings.reconnect_min_s + attempt = 0 + + while True: + attempt += 1 + try: + stub = self._channel.get_stub() + response = await stub.StartSession( + request, + metadata=self._auth.get_metadata(), + ) + + if response.status != "started": + raise RuntimeError( + f"Server rejected StartSession: status={response.status!r}" + ) + + self._session_id = response.session_id + self._state = SessionState.ACTIVE + logger.info( + "Session started: id=%s (attempt %d)", self._session_id, attempt + ) + return self._session_id + + except grpc.RpcError as exc: + code = exc.code() if hasattr(exc, "code") else "unknown" + logger.warning( + "StartSession attempt %d failed: code=%s details=%s", + attempt, + code, + exc.details() if hasattr(exc, "details") else exc, + ) + # Treat INVALID_ARGUMENT as non-retriable (bad request data) + if hasattr(exc, "code") and exc.code() is grpc.StatusCode.INVALID_ARGUMENT: + self._state = SessionState.IDLE + self._session_id = None + raise + + except Exception as exc: + logger.warning("StartSession attempt %d failed: %s", attempt, exc) + + # Cap delay and add jitter + jitter = random.uniform(0.0, 1.0) + wait = min(delay, self._settings.reconnect_max_s) + jitter + logger.info("Retrying StartSession in %.2fs (attempt %d)...", wait, attempt) + await asyncio.sleep(wait) + delay = min(delay * 2, self._settings.reconnect_max_s) + + # Stop after the equivalent of max_delay has been reached twice + # (avoids infinite loops against a hard-down server) + if delay >= self._settings.reconnect_max_s and attempt >= 8: + self._state = SessionState.IDLE + self._session_id = None + raise RuntimeError( + f"StartSession failed after {attempt} attempts. Giving up." + ) + + async def end_session(self, reason: str = "normal") -> dict: + """End the current session and return a final stats dict. + + Transitions: ACTIVE -> ENDING -> ENDED + """ + if self._state is not SessionState.ACTIVE: + raise RuntimeError( + f"Cannot end session from state {self._state.value!r}. " + "Session must be ACTIVE." + ) + if self._session_id is None: + raise RuntimeError("No session_id set despite ACTIVE state.") + + self._state = SessionState.ENDING + logger.info( + "Ending session id=%s reason=%r", self._session_id, reason + ) + + request = nightshade_pb2.EndSessionRequest( + session_id=self._session_id, + reason=reason, + ) + + try: + stub = self._channel.get_stub() + response = await stub.EndSession( + request, + metadata=self._auth.get_metadata(), + ) + except grpc.RpcError as exc: + logger.error( + "EndSession gRPC error: code=%s details=%s", + exc.code() if hasattr(exc, "code") else "unknown", + exc.details() if hasattr(exc, "details") else exc, + ) + # Still transition to ENDED — the session is over from our side + self._state = SessionState.ENDED + raise + + stats = _stats_to_dict(response.final_stats) + self._state = SessionState.ENDED + logger.info( + "Session ended: id=%s stats=%s", response.session_id, stats + ) + return stats + + def reset(self) -> None: + """Reset to IDLE (use after an error to allow a fresh start_session call).""" + logger.debug( + "SessionManager reset: %s -> IDLE (session_id=%s cleared)", + self._state.value, + self._session_id, + ) + self._state = SessionState.IDLE + self._session_id = None + + +# ── Helpers ──────────────────────────────────────────────────────────────── + + +def _stats_to_dict(stats: nightshade_pb2.SessionStats) -> dict: + """Convert a SessionStats protobuf message to a plain dict.""" + return { + "total_messages": stats.total_messages, + "total_alerts": stats.total_alerts, + "subjects_flagged": stats.subjects_flagged, + "evidence_files": stats.evidence_files, + } diff --git a/client/nightshade_client/session/monitor.py b/client/nightshade_client/session/monitor.py new file mode 100644 index 0000000..db6b14d --- /dev/null +++ b/client/nightshade_client/session/monitor.py @@ -0,0 +1,367 @@ +""" +Bidirectional stream monitor for a Nightshade session. + +Wraps MonitorSession with: + - Periodic heartbeat sends + - Queue-based outbound event dispatch (alerts, session updates) + - Callback-based inbound event dispatch (alerts, command acks, session updates) + - Auto-reconnect with exponential backoff on stream drop +""" + +import asyncio +import logging +import random +import time +from typing import Callable, Optional + +import grpc + +from nightshade_client.config import ClientSettings +from nightshade_client.connection.auth import ApiKeyMetadata +from nightshade_client.connection.channel import ChannelManager +from nightshade_client.grpc_stubs import nightshade_pb2 + +logger = logging.getLogger(__name__) + +# Sentinel placed on the send queue to signal the sender loop to exit cleanly +_STOP_SENTINEL = object() + + +class SessionMonitor: + """Wraps MonitorSession bidi stream with heartbeat and event dispatch. + + Usage:: + + monitor = SessionMonitor(channel_manager, auth, settings) + monitor.set_session_id(session_id) + monitor.on_alert(handle_alert) + monitor.on_command(handle_command_ack) + monitor.on_update(handle_session_update) + + asyncio.create_task(monitor.run()) + # ... later ... + await monitor.stop() + """ + + def __init__( + self, + channel_manager: ChannelManager, + auth: ApiKeyMetadata, + settings: ClientSettings, + ) -> None: + self._channel = channel_manager + self._auth = auth + self._settings = settings + + self._session_id: Optional[str] = None + self._running: bool = False + + # Event callbacks (registered before run()) + self._on_alert: Optional[Callable] = None + self._on_command: Optional[Callable] = None + self._on_update: Optional[Callable] = None + + # Heartbeat payload + self._messages_buffered: int = 0 + + # Outbound event queue (AlertReport wrapped in ClientEvent) + self._send_queue: asyncio.Queue = asyncio.Queue() + + # ── Configuration ────────────────────────────────────────────────────── + + def set_session_id(self, session_id: str) -> None: + """Set (or update) the session ID used in all outbound events.""" + self._session_id = session_id + logger.debug("SessionMonitor: session_id set to %s", session_id) + + def on_alert(self, callback: Callable) -> None: + """Register a callback for AlertNotification server events. + + Signature: callback(alert_notification: nightshade_pb2.AlertNotification) + May be a coroutine function; will be awaited if so. + """ + self._on_alert = callback + + def on_command(self, callback: Callable) -> None: + """Register a callback for CommandAck server events. + + Signature: callback(command_ack: nightshade_pb2.CommandAck) + May be a coroutine function; will be awaited if so. + """ + self._on_command = callback + + def on_update(self, callback: Callable) -> None: + """Register a callback for SessionUpdate server events. + + Signature: callback(session_update: nightshade_pb2.SessionUpdate) + May be a coroutine function; will be awaited if so. + """ + self._on_update = callback + + def update_buffered_count(self, count: int) -> None: + """Update the messages_buffered counter reported in heartbeats.""" + self._messages_buffered = count + + # ── Main Loop ────────────────────────────────────────────────────────── + + async def run(self) -> None: + """Open the MonitorSession bidi stream and keep it alive. + + Spawns _heartbeat_sender and _event_receiver as concurrent tasks. + On stream failure, backs off and reconnects automatically. + + Returns when stop() is called. + """ + if self._session_id is None: + raise RuntimeError( + "SessionMonitor.run() called without a session_id. " + "Call set_session_id() first." + ) + + self._running = True + delay = self._settings.reconnect_min_s + attempt = 0 + + while self._running: + attempt += 1 + logger.info( + "MonitorSession: opening stream (session=%s, attempt=%d)", + self._session_id, + attempt, + ) + + try: + stub = self._channel.get_stub() + + # grpc.aio bidi streams are opened by calling the stub method; + # the stream object is both an async iterator (reads) and + # supports write() / done_writing(). + stream = stub.MonitorSession( + metadata=self._auth.get_metadata() + ) + + sender_task = asyncio.create_task( + self._heartbeat_sender(stream), + name="monitor-heartbeat-sender", + ) + receiver_task = asyncio.create_task( + self._event_receiver(stream), + name="monitor-event-receiver", + ) + + # Block until one of the tasks finishes (error or stop signal) + done, pending = await asyncio.wait( + {sender_task, receiver_task}, + return_when=asyncio.FIRST_COMPLETED, + ) + + # Cancel the surviving task so we can reopen cleanly + for task in pending: + task.cancel() + try: + await task + except (asyncio.CancelledError, Exception): + pass + + # Surface any exception from the finished task + for task in done: + exc = task.exception() + if exc is not None and self._running: + logger.warning( + "MonitorSession task raised: %s: %s", + type(exc).__name__, + exc, + ) + + if not self._running: + # Clean stop requested + try: + await stream.done_writing() + except Exception: + pass + logger.info("MonitorSession: stopped cleanly.") + return + + # Abnormal exit — wait before reconnecting + jitter = random.uniform(0.0, 1.0) + wait = min(delay, self._settings.reconnect_max_s) + jitter + logger.info( + "MonitorSession: stream dropped, reconnecting in %.2fs...", wait + ) + await asyncio.sleep(wait) + delay = min(delay * 2, self._settings.reconnect_max_s) + + except grpc.RpcError as exc: + code = exc.code() if hasattr(exc, "code") else "unknown" + logger.warning( + "MonitorSession gRPC error (attempt %d): code=%s details=%s", + attempt, + code, + exc.details() if hasattr(exc, "details") else exc, + ) + if not self._running: + return + + jitter = random.uniform(0.0, 1.0) + wait = min(delay, self._settings.reconnect_max_s) + jitter + logger.info("MonitorSession: retrying in %.2fs...", wait) + await asyncio.sleep(wait) + delay = min(delay * 2, self._settings.reconnect_max_s) + + except Exception as exc: + logger.error( + "MonitorSession unexpected error (attempt %d): %s", attempt, exc + ) + if not self._running: + return + + jitter = random.uniform(0.0, 1.0) + wait = min(delay, self._settings.reconnect_max_s) + jitter + await asyncio.sleep(wait) + delay = min(delay * 2, self._settings.reconnect_max_s) + + async def stop(self) -> None: + """Signal the monitor loop to stop after the current stream iteration.""" + logger.info("SessionMonitor: stop requested.") + self._running = False + # Unblock the sender if it is waiting on the queue + await self._send_queue.put(_STOP_SENTINEL) + + # ── Internal Stream Handlers ──────────────────────────────────────────── + + async def _heartbeat_sender(self, stream) -> None: + """Send heartbeat ClientEvents on a fixed interval, plus queued alerts. + + Interleaves queue drains with timed heartbeats so alert sends are not + delayed by the heartbeat interval. + """ + interval = float(self._settings.heartbeat_interval_s) + + while self._running: + deadline = asyncio.get_event_loop().time() + interval + + # Drain any queued outbound events first + while not self._send_queue.empty(): + item = self._send_queue.get_nowait() + if item is _STOP_SENTINEL: + return + # item is an AlertReport proto + event = nightshade_pb2.ClientEvent( + session_id=self._session_id, + alert=item, + ) + await stream.write(event) + logger.debug( + "Sent AlertReport for subject=%s", item.subject_username + ) + + # Send heartbeat + hb = nightshade_pb2.Heartbeat( + timestamp=_now_millis(), + messages_buffered=self._messages_buffered, + ) + event = nightshade_pb2.ClientEvent( + session_id=self._session_id, + heartbeat=hb, + ) + await stream.write(event) + logger.debug( + "Heartbeat sent: buffered=%d ts=%d", + self._messages_buffered, + hb.timestamp, + ) + + # Sleep until next heartbeat, but check stop/queue periodically + remaining = deadline - asyncio.get_event_loop().time() + if remaining > 0: + try: + # Use wait_for to check the queue with a timeout so we + # can send queued alerts without waiting the full interval + item = await asyncio.wait_for( + self._send_queue.get(), timeout=remaining + ) + if item is _STOP_SENTINEL: + return + # Got a queued alert mid-sleep; send it immediately + event = nightshade_pb2.ClientEvent( + session_id=self._session_id, + alert=item, + ) + await stream.write(event) + logger.debug( + "Sent queued AlertReport for subject=%s", + item.subject_username, + ) + except asyncio.TimeoutError: + # No queued event — normal heartbeat cadence continues + pass + + async def _event_receiver(self, stream) -> None: + """Receive ServerEvents from the bidi stream and dispatch to callbacks.""" + async for event in stream: + if not self._running: + break + + payload_type = event.WhichOneof("payload") + + if payload_type == "alert_notification": + logger.debug( + "Received AlertNotification: alert_id=%s type=%s score=%.2f", + event.alert_notification.alert_id, + event.alert_notification.alert_type, + event.alert_notification.risk_score, + ) + if self._on_alert is not None: + await _maybe_await(self._on_alert, event.alert_notification) + + elif payload_type == "command_ack": + logger.debug( + "Received CommandAck: command_id=%s accepted=%s", + event.command_ack.command_id, + event.command_ack.accepted, + ) + if self._on_command is not None: + await _maybe_await(self._on_command, event.command_ack) + + elif payload_type == "session_update": + logger.debug( + "Received SessionUpdate: status=%s active_subjects=%d", + event.session_update.status, + event.session_update.active_subjects, + ) + if self._on_update is not None: + await _maybe_await(self._on_update, event.session_update) + + else: + logger.warning( + "Received ServerEvent with unknown payload type: %r", payload_type + ) + + # ── Outbound API ─────────────────────────────────────────────────────── + + async def send_alert(self, alert_report: nightshade_pb2.AlertReport) -> None: + """Enqueue an AlertReport to be sent through the bidi stream. + + Thread-safe; can be called from any coroutine while run() is active. + """ + await self._send_queue.put(alert_report) + logger.debug( + "AlertReport enqueued: subject=%s type=%s", + alert_report.subject_username, + alert_report.alert_type, + ) + + +# ── Helpers ──────────────────────────────────────────────────────────────── + + +def _now_millis() -> int: + """Return current Unix time in milliseconds.""" + return int(time.time() * 1000) + + +async def _maybe_await(callback: Callable, *args) -> None: + """Call callback(*args). Await it if it's a coroutine function.""" + result = callback(*args) + if asyncio.iscoroutine(result): + await result diff --git a/client/nightshade_client/streams/__init__.py b/client/nightshade_client/streams/__init__.py new file mode 100644 index 0000000..4145df8 --- /dev/null +++ b/client/nightshade_client/streams/__init__.py @@ -0,0 +1,9 @@ +from nightshade_client.streams.chat_stream import ChatStreamManager +from nightshade_client.streams.voice_stream import VoiceStreamManager +from nightshade_client.streams.evidence_stream import EvidenceStreamManager + +__all__ = [ + "ChatStreamManager", + "VoiceStreamManager", + "EvidenceStreamManager", +] diff --git a/client/nightshade_client/streams/chat_stream.py b/client/nightshade_client/streams/chat_stream.py new file mode 100644 index 0000000..f330757 --- /dev/null +++ b/client/nightshade_client/streams/chat_stream.py @@ -0,0 +1,185 @@ +""" +ChatStreamManager: batches chat messages and streams them to the Nightshade server. +""" + +import asyncio +import logging +import time +from typing import Optional + +from nightshade_client.config import ClientSettings +from nightshade_client.connection import ChannelManager, ApiKeyMetadata +from nightshade_client.grpc_stubs import nightshade_pb2 + +logger = logging.getLogger(__name__) + + +class ChatStreamManager: + """Batches chat messages and streams them to the server. + + Messages are accumulated in an asyncio.Queue. The run() loop drains + the queue in batches based on batch_size or batch_timeout, then sends + each batch via the StreamChatMessages client-streaming RPC. + """ + + def __init__( + self, + channel_manager: ChannelManager, + auth: ApiKeyMetadata, + settings: ClientSettings, + ) -> None: + self._channel = channel_manager + self._auth = auth + self._settings = settings + self._queue: asyncio.Queue[nightshade_pb2.ChatMessageUpload] = asyncio.Queue() + self._session_id: Optional[str] = None + self._running: bool = False + + # ── Public API ────────────────────────────────────────────────────────── + + def set_session_id(self, session_id: str) -> None: + """Set the session ID to embed in every outgoing message.""" + self._session_id = session_id + + async def enqueue( + self, + username: str, + user_id: str, + message: str, + source: str = "game_chat", + confidence: float = 1.0, + risk_flags: Optional[list[str]] = None, + ) -> None: + """Add a chat message to the send queue. + + The message is built into a ChatMessageUpload proto here so that + the timestamp captured is as close to the observed event as possible. + """ + proto = nightshade_pb2.ChatMessageUpload( + session_id=self._session_id or "", + username=username, + user_id=user_id, + message=message, + source=source, + confidence=confidence, + risk_flags=risk_flags or [], + timestamp=int(time.time() * 1000), + ) + await self._queue.put(proto) + logger.debug( + "Enqueued chat message from %s (queue size ~%d)", + username, + self._queue.qsize(), + ) + + async def run(self) -> None: + """Main loop: collect batches and stream them to the server. + + Batch criteria: settings.chat_batch_size messages OR + settings.chat_batch_timeout_s seconds elapsed — whichever comes first. + Each batch is sent as a single StreamChatMessages RPC call. + The loop exits when stop() is called and the queue is empty. + """ + self._running = True + logger.info("ChatStreamManager: run loop started.") + + while self._running or not self._queue.empty(): + batch = await self._collect_batch() + if batch: + await self._send_batch(batch) + + logger.info("ChatStreamManager: run loop finished.") + + async def stop(self) -> None: + """Signal the run loop to stop and flush any remaining messages.""" + logger.info("ChatStreamManager: stopping (will flush remaining messages).") + self._running = False + + # Drain whatever is left in the queue. + remaining: list[nightshade_pb2.ChatMessageUpload] = [] + while not self._queue.empty(): + try: + remaining.append(self._queue.get_nowait()) + except asyncio.QueueEmpty: + break + + if remaining: + logger.info( + "ChatStreamManager: flushing %d remaining messages.", len(remaining) + ) + await self._send_batch(remaining) + + # ── Internal helpers ──────────────────────────────────────────────────── + + async def _collect_batch( + self, + ) -> list[nightshade_pb2.ChatMessageUpload]: + """Collect up to chat_batch_size messages within chat_batch_timeout_s. + + Blocks waiting for the first item, then greedily pulls subsequent + items from the queue without blocking until the batch is full or + the timeout expires. + """ + batch: list[nightshade_pb2.ChatMessageUpload] = [] + deadline = ( + asyncio.get_event_loop().time() + self._settings.chat_batch_timeout_s + ) + + # Wait for at least one message (with timeout so run() can check + # _running periodically even when the queue stays empty). + try: + first = await asyncio.wait_for( + self._queue.get(), + timeout=self._settings.chat_batch_timeout_s, + ) + batch.append(first) + except asyncio.TimeoutError: + return batch + + # Greedily collect more messages until batch_size or deadline. + while len(batch) < self._settings.chat_batch_size: + remaining_time = deadline - asyncio.get_event_loop().time() + if remaining_time <= 0: + break + try: + msg = await asyncio.wait_for( + self._queue.get(), timeout=remaining_time + ) + batch.append(msg) + except asyncio.TimeoutError: + break + + return batch + + async def _send_batch( + self, batch: list[nightshade_pb2.ChatMessageUpload] + ) -> None: + """Stream a batch to the server via client-streaming RPC.""" + stub = self._channel.get_stub() + + async def message_iterator(): + for msg in batch: + yield msg + + try: + ack: nightshade_pb2.StreamAck = await stub.StreamChatMessages( + message_iterator(), + metadata=self._auth(), + ) + logger.info( + "ChatStreamManager: batch of %d sent, server ack: %d received, errors: %s", + len(batch), + ack.messages_received, + ack.errors or [], + ) + if ack.errors: + logger.warning( + "ChatStreamManager: server reported errors: %s", list(ack.errors) + ) + except Exception as exc: + logger.error( + "ChatStreamManager: failed to send batch of %d messages: %s", + len(batch), + exc, + exc_info=True, + ) diff --git a/client/nightshade_client/streams/evidence_stream.py b/client/nightshade_client/streams/evidence_stream.py new file mode 100644 index 0000000..46ccc12 --- /dev/null +++ b/client/nightshade_client/streams/evidence_stream.py @@ -0,0 +1,142 @@ +""" +EvidenceStreamManager: uploads evidence files (screenshots, recordings, +chat logs) in fixed-size chunks via the UploadEvidence RPC. +""" + +import asyncio +import hashlib +import logging +import math +from pathlib import Path +from typing import Optional + +from nightshade_client.config import ClientSettings +from nightshade_client.connection import ChannelManager, ApiKeyMetadata +from nightshade_client.grpc_stubs import nightshade_pb2 + +logger = logging.getLogger(__name__) + + +class EvidenceStreamManager: + """Uploads evidence files in chunks to the server. + + Files are read into memory, split into chunks of + settings.evidence_chunk_size bytes, and streamed via the + UploadEvidence client-streaming RPC. The full-file SHA-256 hash is + attached to the final chunk so the server can verify integrity. + """ + + def __init__( + self, + channel_manager: ChannelManager, + auth: ApiKeyMetadata, + settings: ClientSettings, + ) -> None: + self._channel = channel_manager + self._auth = auth + self._settings = settings + self._session_id: Optional[str] = None + + # ── Public API ────────────────────────────────────────────────────────── + + def set_session_id(self, session_id: str) -> None: + """Set the session ID to embed in every outgoing chunk.""" + self._session_id = session_id + + async def upload_file( + self, + path: Path, + file_type: str = "screenshot", + ) -> dict: + """Read a file from disk, chunk it, and stream to the server. + + Args: + path: Path to the file on disk. + file_type: One of "screenshot", "recording", "chat_log". + + Returns: + dict with keys: file_id, sha256_verified, stored_path + """ + logger.info( + "EvidenceStreamManager: reading %s (%s)", path, file_type + ) + data = path.read_bytes() + return await self.upload_bytes(data, filename=path.name, file_type=file_type) + + async def upload_bytes( + self, + data: bytes, + filename: str, + file_type: str = "screenshot", + ) -> dict: + """Upload raw bytes as evidence. + + Args: + data: Raw file content. + filename: Logical filename to report to the server. + file_type: One of "screenshot", "recording", "chat_log". + + Returns: + dict with keys: file_id, sha256_verified, stored_path + """ + sha256_hex = hashlib.sha256(data).hexdigest() + chunk_size = self._settings.evidence_chunk_size + total_chunks = max(1, math.ceil(len(data) / chunk_size)) + + logger.info( + "EvidenceStreamManager: uploading '%s' (%d bytes, %d chunks, sha256=%s...)", + filename, + len(data), + total_chunks, + sha256_hex[:12], + ) + + stub = self._channel.get_stub() + + async def chunk_iterator(): + for idx in range(total_chunks): + start = idx * chunk_size + end = start + chunk_size + chunk_data = data[start:end] + + # Only include the hash on the last chunk so the server + # can verify the complete file once all chunks arrive. + is_last = idx == total_chunks - 1 + + yield nightshade_pb2.EvidenceChunk( + session_id=self._session_id or "", + filename=filename, + chunk_data=chunk_data, + chunk_index=idx, + total_chunks=total_chunks, + sha256_hash=sha256_hex if is_last else "", + file_type=file_type, + ) + + try: + response: nightshade_pb2.UploadResponse = await stub.UploadEvidence( + chunk_iterator(), + metadata=self._auth(), + ) + result = { + "file_id": response.file_id, + "sha256_verified": response.sha256_verified, + "stored_path": response.stored_path, + } + logger.info( + "EvidenceStreamManager: upload complete — file_id=%s, " + "sha256_verified=%s, stored_path=%s", + result["file_id"], + result["sha256_verified"], + result["stored_path"], + ) + return result + + except Exception as exc: + logger.error( + "EvidenceStreamManager: failed to upload '%s': %s", + filename, + exc, + exc_info=True, + ) + raise diff --git a/client/nightshade_client/streams/voice_stream.py b/client/nightshade_client/streams/voice_stream.py new file mode 100644 index 0000000..cc65548 --- /dev/null +++ b/client/nightshade_client/streams/voice_stream.py @@ -0,0 +1,180 @@ +""" +VoiceStreamManager: batches voice transcript segments and streams them to the +Nightshade server. +""" + +import asyncio +import logging +import time +from typing import Optional + +from nightshade_client.config import ClientSettings +from nightshade_client.connection import ChannelManager, ApiKeyMetadata +from nightshade_client.grpc_stubs import nightshade_pb2 + +logger = logging.getLogger(__name__) + + +class VoiceStreamManager: + """Batches voice transcripts and streams them to the server. + + Mirrors the structure of ChatStreamManager but operates on + VoiceTranscriptUpload protos and uses the StreamVoiceTranscripts RPC. + Batch parameters are driven by settings.voice_batch_size and + settings.voice_batch_timeout_s. + """ + + def __init__( + self, + channel_manager: ChannelManager, + auth: ApiKeyMetadata, + settings: ClientSettings, + ) -> None: + self._channel = channel_manager + self._auth = auth + self._settings = settings + self._queue: asyncio.Queue[nightshade_pb2.VoiceTranscriptUpload] = ( + asyncio.Queue() + ) + self._session_id: Optional[str] = None + self._running: bool = False + + # ── Public API ────────────────────────────────────────────────────────── + + def set_session_id(self, session_id: str) -> None: + """Set the session ID to embed in every outgoing transcript.""" + self._session_id = session_id + + async def enqueue( + self, + speaker_label: str, + user_id_guess: str, + text: str, + confidence: float = 1.0, + diarization_confidence: float = 1.0, + risk_flags: Optional[list[str]] = None, + ) -> None: + """Add a voice transcript segment to the send queue. + + The proto is built immediately so the timestamp reflects the time the + segment was observed, not when it is eventually transmitted. + """ + proto = nightshade_pb2.VoiceTranscriptUpload( + session_id=self._session_id or "", + speaker_label=speaker_label, + user_id_guess=user_id_guess, + text=text, + confidence=confidence, + diarization_confidence=diarization_confidence, + risk_flags=risk_flags or [], + timestamp=int(time.time() * 1000), + ) + await self._queue.put(proto) + logger.debug( + "Enqueued voice transcript from %s (queue size ~%d)", + speaker_label, + self._queue.qsize(), + ) + + async def run(self) -> None: + """Main loop: collect batches and stream them to the server. + + Batch criteria: settings.voice_batch_size transcripts OR + settings.voice_batch_timeout_s seconds elapsed. + The loop exits after stop() is called and the queue is fully drained. + """ + self._running = True + logger.info("VoiceStreamManager: run loop started.") + + while self._running or not self._queue.empty(): + batch = await self._collect_batch() + if batch: + await self._send_batch(batch) + + logger.info("VoiceStreamManager: run loop finished.") + + async def stop(self) -> None: + """Signal the run loop to stop and flush any remaining transcripts.""" + logger.info("VoiceStreamManager: stopping (will flush remaining transcripts).") + self._running = False + + remaining: list[nightshade_pb2.VoiceTranscriptUpload] = [] + while not self._queue.empty(): + try: + remaining.append(self._queue.get_nowait()) + except asyncio.QueueEmpty: + break + + if remaining: + logger.info( + "VoiceStreamManager: flushing %d remaining transcripts.", + len(remaining), + ) + await self._send_batch(remaining) + + # ── Internal helpers ──────────────────────────────────────────────────── + + async def _collect_batch( + self, + ) -> list[nightshade_pb2.VoiceTranscriptUpload]: + """Collect up to voice_batch_size transcripts within voice_batch_timeout_s.""" + batch: list[nightshade_pb2.VoiceTranscriptUpload] = [] + deadline = ( + asyncio.get_event_loop().time() + self._settings.voice_batch_timeout_s + ) + + try: + first = await asyncio.wait_for( + self._queue.get(), + timeout=self._settings.voice_batch_timeout_s, + ) + batch.append(first) + except asyncio.TimeoutError: + return batch + + while len(batch) < self._settings.voice_batch_size: + remaining_time = deadline - asyncio.get_event_loop().time() + if remaining_time <= 0: + break + try: + msg = await asyncio.wait_for( + self._queue.get(), timeout=remaining_time + ) + batch.append(msg) + except asyncio.TimeoutError: + break + + return batch + + async def _send_batch( + self, batch: list[nightshade_pb2.VoiceTranscriptUpload] + ) -> None: + """Stream a batch to the server via client-streaming RPC.""" + stub = self._channel.get_stub() + + async def message_iterator(): + for msg in batch: + yield msg + + try: + ack: nightshade_pb2.StreamAck = await stub.StreamVoiceTranscripts( + message_iterator(), + metadata=self._auth(), + ) + logger.info( + "VoiceStreamManager: batch of %d sent, server ack: %d received, errors: %s", + len(batch), + ack.messages_received, + ack.errors or [], + ) + if ack.errors: + logger.warning( + "VoiceStreamManager: server reported errors: %s", list(ack.errors) + ) + except Exception as exc: + logger.error( + "VoiceStreamManager: failed to send batch of %d transcripts: %s", + len(batch), + exc, + exc_info=True, + ) diff --git a/client/proto/nightshade.proto b/client/proto/nightshade.proto new file mode 120000 index 0000000..f6c38a5 --- /dev/null +++ b/client/proto/nightshade.proto @@ -0,0 +1 @@ +../../server/proto/nightshade.proto \ No newline at end of file diff --git a/client/pyproject.toml b/client/pyproject.toml new file mode 100644 index 0000000..c76470c --- /dev/null +++ b/client/pyproject.toml @@ -0,0 +1,32 @@ +[build-system] +requires = ["setuptools>=68.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "nightshade-client" +version = "0.1.0" +description = "Nightshade monitoring client SDK" +requires-python = ">=3.11" +dependencies = [ + "grpcio>=1.60.0", + "grpcio-tools>=1.60.0", + "pydantic>=2.0", + "pydantic-settings>=2.0", + "numpy>=1.26", +] + +[project.optional-dependencies] +dev = [ + "pytest>=8.0", + "pytest-asyncio>=0.23", + "pytest-cov>=4.0", + "grpcio-testing>=1.60.0", +] + +[tool.setuptools.packages.find] +include = ["nightshade_client*"] + +[tool.pytest.ini_options] +asyncio_mode = "auto" +asyncio_default_fixture_loop_scope = "function" +testpaths = ["tests"] diff --git a/client/requirements-windows.txt b/client/requirements-windows.txt new file mode 100644 index 0000000..b21c466 --- /dev/null +++ b/client/requirements-windows.txt @@ -0,0 +1,9 @@ +# Windows + NVIDIA GPU dependencies (Part B) +# Install AFTER requirements.txt on a Windows machine with CUDA +paddlepaddle-gpu>=2.6 +paddleocr>=2.7 +faster-whisper>=1.0 +nemo_toolkit[asr]>=1.22 +sounddevice>=0.4 +mss>=9.0 +dearpygui>=1.10 diff --git a/client/requirements.txt b/client/requirements.txt new file mode 100644 index 0000000..5c545cd --- /dev/null +++ b/client/requirements.txt @@ -0,0 +1,12 @@ +# Platform-independent dependencies +grpcio>=1.60.0 +grpcio-tools>=1.60.0 +pydantic>=2.0 +pydantic-settings>=2.0 +numpy>=1.26 + +# Dev / test +pytest>=8.0 +pytest-asyncio>=0.23 +pytest-cov>=4.0 +grpcio-testing>=1.60.0 diff --git a/client/tests/__init__.py b/client/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/client/tests/conftest.py b/client/tests/conftest.py new file mode 100644 index 0000000..a8a0172 --- /dev/null +++ b/client/tests/conftest.py @@ -0,0 +1,180 @@ +""" +Shared fixtures for the Nightshade client SDK test suite. + +Provides: +- test_settings: ClientSettings with known-good test values +- mock_grpc_server: function-scoped async gRPC server with a MockServicer +- channel_manager: ChannelManager wired to the mock server +- auth: ApiKeyMetadata with the test API key + +Note on scope: pytest-asyncio 0.25 with asyncio_mode=auto creates a new event +loop per test function. A session-scoped async fixture would run in a different +loop than the tests that use it, causing gRPC channel calls to hang. All async +fixtures are therefore function-scoped. +""" + +import asyncio +import hashlib +import uuid + +import grpc +import grpc.aio +import pytest +import pytest_asyncio + +from nightshade_client.config import ClientSettings +from nightshade_client.connection import ApiKeyMetadata, ChannelManager +from nightshade_client.grpc_stubs import nightshade_pb2, nightshade_pb2_grpc + + +# ── Mock Servicer ───────────────────────────────────────────────────────────── + + +class MockNightshadeServicer(nightshade_pb2_grpc.NightshadeServiceServicer): + """In-process mock implementation of NightshadeServiceServicer. + + Records calls so tests can assert on what arrived. + """ + + def __init__(self) -> None: + self.chat_messages_received: list[nightshade_pb2.ChatMessageUpload] = [] + self.voice_transcripts_received: list[nightshade_pb2.VoiceTranscriptUpload] = [] + self.evidence_chunks_received: list[nightshade_pb2.EvidenceChunk] = [] + self.heartbeats_received: list[nightshade_pb2.Heartbeat] = [] + self.session_started: bool = False + self.session_ended: bool = False + + async def StartSession(self, request, context): + self.session_started = True + session_id = str(uuid.uuid4()) + return nightshade_pb2.StartSessionResponse( + session_id=session_id, + status="started", + ) + + async def EndSession(self, request, context): + self.session_ended = True + stats = nightshade_pb2.SessionStats( + total_messages=len(self.chat_messages_received), + total_alerts=0, + subjects_flagged=0, + evidence_files=0, + ) + return nightshade_pb2.EndSessionResponse( + session_id=request.session_id, + final_stats=stats, + ) + + async def StreamChatMessages(self, request_iterator, context): + count = 0 + async for msg in request_iterator: + self.chat_messages_received.append(msg) + count += 1 + return nightshade_pb2.StreamAck(messages_received=count, errors=[]) + + async def StreamVoiceTranscripts(self, request_iterator, context): + count = 0 + async for msg in request_iterator: + self.voice_transcripts_received.append(msg) + count += 1 + return nightshade_pb2.StreamAck(messages_received=count, errors=[]) + + async def UploadEvidence(self, request_iterator, context): + chunks: list[nightshade_pb2.EvidenceChunk] = [] + async for chunk in request_iterator: + self.evidence_chunks_received.append(chunk) + chunks.append(chunk) + + # Verify SHA-256 on last chunk + all_data = b"".join(c.chunk_data for c in chunks) + computed = hashlib.sha256(all_data).hexdigest() + last_hash = chunks[-1].sha256_hash if chunks else "" + verified = computed == last_hash + + file_id = str(uuid.uuid4()) + return nightshade_pb2.UploadResponse( + file_id=file_id, + sha256_verified=verified, + stored_path=f"/evidence/{file_id}", + ) + + async def MonitorSession(self, request_iterator, context): + """Echo each received heartbeat back as a SessionUpdate ServerEvent.""" + async for event in request_iterator: + payload_type = event.WhichOneof("payload") + if payload_type == "heartbeat": + self.heartbeats_received.append(event.heartbeat) + update = nightshade_pb2.SessionUpdate( + session_id=event.session_id, + status="active", + active_subjects=0, + ) + yield nightshade_pb2.ServerEvent( + session_id=event.session_id, + session_update=update, + ) + + +# ── Fixtures ────────────────────────────────────────────────────────────────── + + +@pytest.fixture +def test_settings() -> ClientSettings: + """ClientSettings with stable test defaults (does not read .env or env vars).""" + return ClientSettings( + server_host="localhost", + server_port=50051, + api_key="test-key-123", + operator_id="00000000-0000-0000-0000-000000000001", + # Reduce timeouts so tests complete quickly + reconnect_min_s=0.1, + reconnect_max_s=1.0, + heartbeat_interval_s=1, + chat_batch_timeout_s=0.2, + voice_batch_timeout_s=0.2, + ) + + +@pytest_asyncio.fixture +async def mock_grpc_server(): + """Start a per-test in-process gRPC server on a random port. + + Yields a tuple of (servicer, port) so tests can inspect recorded calls + and know which port the server is listening on. + + Function-scoped so each test gets its own event loop and server instance, + which avoids cross-loop gRPC channel issues with pytest-asyncio 0.25. + """ + servicer = MockNightshadeServicer() + server = grpc.aio.server() + nightshade_pb2_grpc.add_NightshadeServiceServicer_to_server(servicer, server) + + port = server.add_insecure_port("[::]:0") + await server.start() + + yield servicer, port + + await server.stop(grace=0) + + +@pytest_asyncio.fixture +async def channel_manager(mock_grpc_server, test_settings): + """ChannelManager connected to the mock gRPC server. + + Updates test_settings.server_port to match the random port used by + mock_grpc_server so the channel targets the correct endpoint. + """ + _, port = mock_grpc_server + # Patch the port on the settings object directly + object.__setattr__(test_settings, "server_port", port) + + mgr = ChannelManager(test_settings) + await mgr.connect() + yield mgr + await mgr.close() + + +@pytest.fixture +def auth(test_settings) -> ApiKeyMetadata: + """ApiKeyMetadata configured with the test API key.""" + return ApiKeyMetadata(test_settings.api_key) diff --git a/client/tests/test_config.py b/client/tests/test_config.py new file mode 100644 index 0000000..ba76e5f --- /dev/null +++ b/client/tests/test_config.py @@ -0,0 +1,76 @@ +""" +Tests for nightshade_client.config.ClientSettings. +""" + +import pytest + +from nightshade_client.config import ClientSettings + + +def test_default_values(): + """ClientSettings should have sensible defaults without any env vars.""" + settings = ClientSettings( + _env_file=None, # suppress .env loading + ) + assert settings.server_host == "localhost" + assert settings.server_port == 50051 + assert settings.api_key == "" + assert settings.use_tls is False + assert settings.operator_id == "" + assert settings.capture_interval_ms == 500 + assert settings.whisper_model == "base" + assert settings.buffer_max_messages == 1000 + assert settings.heartbeat_interval_s == 10 + assert settings.evidence_chunk_size == 1_048_576 + assert settings.chat_batch_size == 50 + assert settings.log_level == "INFO" + + +def test_env_var_override(monkeypatch): + """NIGHTSHADE_* env vars should override defaults.""" + monkeypatch.setenv("NIGHTSHADE_SERVER_HOST", "10.0.0.1") + monkeypatch.setenv("NIGHTSHADE_SERVER_PORT", "9090") + monkeypatch.setenv("NIGHTSHADE_API_KEY", "env-key-abc") + monkeypatch.setenv("NIGHTSHADE_USE_TLS", "true") + + settings = ClientSettings(_env_file=None) + + assert settings.server_host == "10.0.0.1" + assert settings.server_port == 9090 + assert settings.api_key == "env-key-abc" + assert settings.use_tls is True + + +def test_env_file_loading(tmp_path): + """ClientSettings should read from a .env file when provided.""" + env_file = tmp_path / ".env" + env_file.write_text( + "NIGHTSHADE_SERVER_HOST=file-host\n" + "NIGHTSHADE_SERVER_PORT=12345\n" + "NIGHTSHADE_API_KEY=file-key-xyz\n" + ) + + settings = ClientSettings(_env_file=str(env_file)) + + assert settings.server_host == "file-host" + assert settings.server_port == 12345 + assert settings.api_key == "file-key-xyz" + + +def test_field_types(): + """Pydantic should coerce and validate field types correctly.""" + settings = ClientSettings( + server_port=8080, + use_tls=False, + capture_interval_ms=250, + evidence_chunk_size=512_000, + reconnect_min_s=2.5, + _env_file=None, + ) + + assert isinstance(settings.server_port, int) + assert isinstance(settings.use_tls, bool) + assert isinstance(settings.capture_interval_ms, int) + assert isinstance(settings.evidence_chunk_size, int) + assert isinstance(settings.reconnect_min_s, float) + assert isinstance(settings.server_host, str) diff --git a/client/tests/test_connection/__init__.py b/client/tests/test_connection/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/client/tests/test_connection/test_auth.py b/client/tests/test_connection/test_auth.py new file mode 100644 index 0000000..274b9ee --- /dev/null +++ b/client/tests/test_connection/test_auth.py @@ -0,0 +1,26 @@ +""" +Tests for nightshade_client.connection.ApiKeyMetadata. +""" + +import pytest + +from nightshade_client.connection import ApiKeyMetadata + + +def test_get_metadata_returns_correct_format(): + """get_metadata() should return [('x-api-key', key)] list.""" + auth = ApiKeyMetadata("my-secret-key") + metadata = auth.get_metadata() + + assert isinstance(metadata, list) + assert len(metadata) == 1 + key_name, key_value = metadata[0] + assert key_name == "x-api-key" + assert key_value == "my-secret-key" + + +def test_callable_interface(): + """Calling the object directly should return same result as get_metadata().""" + auth = ApiKeyMetadata("callable-key") + assert auth() == auth.get_metadata() + assert auth()[0] == ("x-api-key", "callable-key") diff --git a/client/tests/test_connection/test_channel.py b/client/tests/test_connection/test_channel.py new file mode 100644 index 0000000..6f51a97 --- /dev/null +++ b/client/tests/test_connection/test_channel.py @@ -0,0 +1,44 @@ +""" +Tests for nightshade_client.connection.ChannelManager. +""" + +import pytest + +from nightshade_client.connection import ChannelManager +from nightshade_client.grpc_stubs import nightshade_pb2_grpc + + +async def test_connect_and_close(mock_grpc_server, test_settings): + """connect() should set is_connected; close() should clear it.""" + _, port = mock_grpc_server + object.__setattr__(test_settings, "server_port", port) + + mgr = ChannelManager(test_settings) + assert not mgr.is_connected + + await mgr.connect() + assert mgr.is_connected + + await mgr.close() + assert not mgr.is_connected + + +async def test_get_stub_returns_stub(channel_manager): + """get_stub() should return a NightshadeServiceStub after connect().""" + stub = channel_manager.get_stub() + assert isinstance(stub, nightshade_pb2_grpc.NightshadeServiceStub) + + +async def test_context_manager(mock_grpc_server, test_settings): + """ChannelManager used as async context manager should auto-close.""" + _, port = mock_grpc_server + object.__setattr__(test_settings, "server_port", port) + + mgr = ChannelManager(test_settings) + async with mgr as cm: + assert cm.is_connected + stub = cm.get_stub() + assert stub is not None + + # After exiting the context, the channel should be closed + assert not mgr.is_connected diff --git a/client/tests/test_pipeline/__init__.py b/client/tests/test_pipeline/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/client/tests/test_pipeline/test_buffer.py b/client/tests/test_pipeline/test_buffer.py new file mode 100644 index 0000000..b99e33f --- /dev/null +++ b/client/tests/test_pipeline/test_buffer.py @@ -0,0 +1,67 @@ +""" +Tests for nightshade_client.pipeline.buffer.RingBuffer. +""" + +import json +import time + +import numpy as np +import pytest + +from nightshade_client.pipeline.base import CapturedFrame +from nightshade_client.pipeline.buffer import RingBuffer + + +def _make_frame(ts: float | None = None) -> CapturedFrame: + """Create a minimal CapturedFrame for testing.""" + image = np.zeros((10, 10, 3), dtype=np.uint8) + return CapturedFrame(image=image, timestamp=ts or time.time()) + + +async def test_append_and_frame_count(): + """append() should increase frame_count; buffer should not exceed maxlen.""" + buf = RingBuffer(max_seconds=5, fps=2) # max 10 frames + assert buf.frame_count == 0 + + for i in range(5): + await buf.append(_make_frame()) + + assert buf.frame_count == 5 + + # Overflow: add more frames than capacity + for i in range(10): + await buf.append(_make_frame()) + + # Should be capped at max_frames (10) + assert buf.frame_count == buf.max_frames + + +async def test_lock_and_save_creates_output_file(tmp_path): + """lock_and_save() should write a JSON marker file and return the output path.""" + buf = RingBuffer(max_seconds=10, fps=2) # max 20 frames + + # Add 4 frames with known timestamps + base_ts = 1_000_000.0 + for i in range(4): + await buf.append(_make_frame(ts=base_ts + i * 0.5)) + + output_path = tmp_path / "capture" / "clip.bin" + result = await buf.lock_and_save( + pre_seconds=2.0, + post_seconds=1.0, + output_path=output_path, + fps=2, + ) + + # The returned path should match what was passed in + assert result == output_path + + # A JSON marker file should have been created alongside + marker = output_path.with_suffix(".json") + assert marker.exists() + + data = json.loads(marker.read_text()) + assert data["pre_seconds"] == 2.0 + assert data["post_seconds"] == 1.0 + assert "frame_timestamps" in data + assert isinstance(data["frame_timestamps"], list) diff --git a/client/tests/test_pipeline/test_orchestrator.py b/client/tests/test_pipeline/test_orchestrator.py new file mode 100644 index 0000000..076727b --- /dev/null +++ b/client/tests/test_pipeline/test_orchestrator.py @@ -0,0 +1,39 @@ +""" +Tests for nightshade_client.pipeline.orchestrator.PipelineOrchestrator. +""" + +import pytest + +from nightshade_client.config import ClientSettings +from nightshade_client.pipeline.orchestrator import PipelineOrchestrator + + +def test_parse_chat_line_bracketed_format(): + """_parse_chat_line() should extract username and message from '[user]: msg'.""" + username, message = PipelineOrchestrator._parse_chat_line("[Alice]: hello there") + assert username == "Alice" + assert message == "hello there" + + +def test_parse_chat_line_no_bracket(): + """_parse_chat_line() should return empty username for non-bracketed input.""" + username, message = PipelineOrchestrator._parse_chat_line("just some raw text") + assert username == "" + assert message == "just some raw text" + + +def test_orchestrator_constructs_with_mocks(test_settings): + """PipelineOrchestrator should construct without error when use_mocks=True.""" + orch = PipelineOrchestrator(test_settings, use_mocks=True) + + # Verify all pipeline components are initialised + assert orch._screen is not None + assert orch._audio is not None + assert orch._video is not None + assert orch._ocr is not None + assert orch._stt is not None + assert orch._diarizer is not None + assert orch._hud is not None + # Infrastructure starts as None until run() is called + assert orch._channel is None + assert orch._session_mgr is None diff --git a/client/tests/test_session/__init__.py b/client/tests/test_session/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/client/tests/test_session/test_manager.py b/client/tests/test_session/test_manager.py new file mode 100644 index 0000000..34c1ac3 --- /dev/null +++ b/client/tests/test_session/test_manager.py @@ -0,0 +1,84 @@ +""" +Tests for nightshade_client.session.SessionManager. +""" + +import pytest + +from nightshade_client.session import SessionManager, SessionState + + +async def test_start_session_returns_session_id_and_transitions( + mock_grpc_server, channel_manager, auth, test_settings +): + """start_session() should return a non-empty session_id and move to ACTIVE.""" + mgr = SessionManager(channel_manager, auth, test_settings) + assert mgr.state is SessionState.IDLE + + session_id = await mgr.start_session( + operator_id="00000000-0000-0000-0000-000000000001", + game_name="TestGame", + ) + + assert isinstance(session_id, str) + assert len(session_id) > 0 + assert mgr.session_id == session_id + assert mgr.state is SessionState.ACTIVE + + +async def test_end_session_returns_stats_and_transitions( + mock_grpc_server, channel_manager, auth, test_settings +): + """end_session() should return a stats dict and transition to ENDED.""" + mgr = SessionManager(channel_manager, auth, test_settings) + + await mgr.start_session( + operator_id="00000000-0000-0000-0000-000000000001", + game_name="TestGame", + ) + assert mgr.state is SessionState.ACTIVE + + stats = await mgr.end_session(reason="normal") + + assert isinstance(stats, dict) + assert "total_messages" in stats + assert "total_alerts" in stats + assert "subjects_flagged" in stats + assert "evidence_files" in stats + assert mgr.state is SessionState.ENDED + + +async def test_start_twice_raises( + mock_grpc_server, channel_manager, auth, test_settings +): + """Calling start_session() while already ACTIVE should raise RuntimeError.""" + mgr = SessionManager(channel_manager, auth, test_settings) + + await mgr.start_session( + operator_id="00000000-0000-0000-0000-000000000001", + game_name="TestGame", + ) + + with pytest.raises(RuntimeError, match="Cannot start session"): + await mgr.start_session( + operator_id="00000000-0000-0000-0000-000000000001", + game_name="TestGame", + ) + + +async def test_reset_returns_to_idle( + mock_grpc_server, channel_manager, auth, test_settings +): + """reset() should move from any state back to IDLE and clear session_id.""" + mgr = SessionManager(channel_manager, auth, test_settings) + + await mgr.start_session( + operator_id="00000000-0000-0000-0000-000000000001", + game_name="TestGame", + ) + assert mgr.state is SessionState.ACTIVE + assert mgr.session_id is not None + + mgr.reset() + + assert mgr.state is SessionState.IDLE + assert mgr.session_id is None diff --git a/client/tests/test_session/test_monitor.py b/client/tests/test_session/test_monitor.py new file mode 100644 index 0000000..c462df9 --- /dev/null +++ b/client/tests/test_session/test_monitor.py @@ -0,0 +1,91 @@ +""" +Tests for nightshade_client.session.SessionMonitor. +""" + +import asyncio + +import pytest + +from nightshade_client.session import SessionMonitor +from nightshade_client.grpc_stubs import nightshade_pb2 + + +async def test_heartbeat_sent_to_mock_server( + mock_grpc_server, channel_manager, auth, test_settings +): + """SessionMonitor.run() should send at least one heartbeat to the server.""" + servicer, _ = mock_grpc_server + # Reset recorded heartbeats from previous tests + servicer.heartbeats_received.clear() + + monitor = SessionMonitor(channel_manager, auth, test_settings) + monitor.set_session_id("test-session-heartbeat") + + # Run the monitor for just long enough to send one heartbeat + # (heartbeat_interval_s is set to 1s in test_settings) + run_task = asyncio.create_task(monitor.run()) + try: + await asyncio.wait_for( + _wait_for_heartbeats(servicer, min_count=1), + timeout=5.0, + ) + finally: + await monitor.stop() + run_task.cancel() + try: + await run_task + except (asyncio.CancelledError, Exception): + pass + + assert len(servicer.heartbeats_received) >= 1 + + +async def test_server_event_dispatches_callback( + mock_grpc_server, channel_manager, auth, test_settings +): + """on_update() callback should be called when server sends a SessionUpdate.""" + servicer, _ = mock_grpc_server + servicer.heartbeats_received.clear() + + received_updates: list = [] + + async def on_update(update): + received_updates.append(update) + + monitor = SessionMonitor(channel_manager, auth, test_settings) + monitor.set_session_id("test-session-update") + monitor.on_update(on_update) + + # The mock server echoes each heartbeat back as a SessionUpdate, + # so once a heartbeat is sent we should receive a callback. + run_task = asyncio.create_task(monitor.run()) + try: + await asyncio.wait_for( + _wait_for_condition(lambda: len(received_updates) >= 1), + timeout=5.0, + ) + finally: + await monitor.stop() + run_task.cancel() + try: + await run_task + except (asyncio.CancelledError, Exception): + pass + + assert len(received_updates) >= 1 + assert hasattr(received_updates[0], "status") + + +# ── Helpers ────────────────────────────────────────────────────────────────── + + +async def _wait_for_heartbeats(servicer, min_count: int) -> None: + """Poll until servicer has received at least min_count heartbeats.""" + while len(servicer.heartbeats_received) < min_count: + await asyncio.sleep(0.05) + + +async def _wait_for_condition(predicate, poll_interval: float = 0.05) -> None: + """Poll until predicate() returns True.""" + while not predicate(): + await asyncio.sleep(poll_interval) diff --git a/client/tests/test_streams/__init__.py b/client/tests/test_streams/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/client/tests/test_streams/test_chat_stream.py b/client/tests/test_streams/test_chat_stream.py new file mode 100644 index 0000000..a3c91ce --- /dev/null +++ b/client/tests/test_streams/test_chat_stream.py @@ -0,0 +1,67 @@ +""" +Tests for nightshade_client.streams.ChatStreamManager. +""" + +import asyncio + +import pytest + +from nightshade_client.streams import ChatStreamManager + + +async def test_enqueue_and_batch_send( + mock_grpc_server, channel_manager, auth, test_settings +): + """enqueue() should add messages to the queue; run()/stop() should deliver them.""" + servicer, _ = mock_grpc_server + servicer.chat_messages_received.clear() + + mgr = ChatStreamManager(channel_manager, auth, test_settings) + mgr.set_session_id("chat-test-session") + + # Enqueue 3 messages before starting the run loop + await mgr.enqueue(username="Alice", user_id="1", message="hello world") + await mgr.enqueue(username="Bob", user_id="2", message="how are you") + await mgr.enqueue(username="Charlie", user_id="3", message="fine thanks") + + # Run briefly so the batch is sent, then stop + run_task = asyncio.create_task(mgr.run()) + await asyncio.sleep(0.05) # yield to event loop + await mgr.stop() + # Give the flush a moment to complete + await asyncio.wait_for(run_task, timeout=3.0) + + assert len(servicer.chat_messages_received) == 3 + + +async def test_messages_have_correct_fields( + mock_grpc_server, channel_manager, auth, test_settings +): + """Messages arriving at the server should carry the correct username and message.""" + servicer, _ = mock_grpc_server + servicer.chat_messages_received.clear() + + mgr = ChatStreamManager(channel_manager, auth, test_settings) + mgr.set_session_id("chat-field-test-session") + + await mgr.enqueue( + username="TestUser", + user_id="u-99", + message="are you alone", + source="game_chat", + confidence=0.95, + ) + + run_task = asyncio.create_task(mgr.run()) + await asyncio.sleep(0.05) + await mgr.stop() + await asyncio.wait_for(run_task, timeout=3.0) + + assert len(servicer.chat_messages_received) >= 1 + msg = servicer.chat_messages_received[-1] + assert msg.username == "TestUser" + assert msg.user_id == "u-99" + assert msg.message == "are you alone" + assert msg.source == "game_chat" + assert abs(msg.confidence - 0.95) < 1e-5 + assert msg.session_id == "chat-field-test-session" diff --git a/client/tests/test_streams/test_evidence_stream.py b/client/tests/test_streams/test_evidence_stream.py new file mode 100644 index 0000000..6fe0f07 --- /dev/null +++ b/client/tests/test_streams/test_evidence_stream.py @@ -0,0 +1,72 @@ +""" +Tests for nightshade_client.streams.EvidenceStreamManager. +""" + +import hashlib + +import pytest + +from nightshade_client.streams import EvidenceStreamManager + + +async def test_upload_bytes_correct_chunks_and_sha256( + mock_grpc_server, channel_manager, auth, test_settings +): + """upload_bytes() should send chunks with the correct SHA-256 on the last chunk.""" + servicer, _ = mock_grpc_server + servicer.evidence_chunks_received.clear() + + # Use a small chunk size so we get multiple chunks from a moderate payload + object.__setattr__(test_settings, "evidence_chunk_size", 64) + + mgr = EvidenceStreamManager(channel_manager, auth, test_settings) + mgr.set_session_id("evidence-test-session") + + data = b"A" * 200 # 200 bytes -> ceil(200/64) = 4 chunks + expected_sha = hashlib.sha256(data).hexdigest() + + result = await mgr.upload_bytes(data, filename="test.bin", file_type="screenshot") + + # Server should have received 4 chunks + assert len(servicer.evidence_chunks_received) == 4 + + # SHA-256 on last chunk should match + last_chunk = servicer.evidence_chunks_received[-1] + assert last_chunk.sha256_hash == expected_sha + + # Non-last chunks should have empty hash + for chunk in servicer.evidence_chunks_received[:-1]: + assert chunk.sha256_hash == "" + + # Server verified the hash + assert result["sha256_verified"] is True + assert result["file_id"] != "" + + +async def test_upload_file_works_with_temp_file( + mock_grpc_server, channel_manager, auth, test_settings, tmp_path +): + """upload_file() should read a file from disk and upload it successfully.""" + servicer, _ = mock_grpc_server + servicer.evidence_chunks_received.clear() + + object.__setattr__(test_settings, "evidence_chunk_size", 1_048_576) + + mgr = EvidenceStreamManager(channel_manager, auth, test_settings) + mgr.set_session_id("evidence-file-session") + + # Write a small temp file + evidence_file = tmp_path / "screenshot.png" + file_content = b"PNG_FAKE_DATA" * 10 + evidence_file.write_bytes(file_content) + + result = await mgr.upload_file(evidence_file, file_type="screenshot") + + assert result["sha256_verified"] is True + assert result["file_id"] != "" + assert result["stored_path"] != "" + + # All data in one chunk (file smaller than chunk_size) + assert len(servicer.evidence_chunks_received) == 1 + assert servicer.evidence_chunks_received[0].filename == "screenshot.png" + assert servicer.evidence_chunks_received[0].file_type == "screenshot"