-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_websocket.py
More file actions
242 lines (198 loc) · 9.33 KB
/
test_websocket.py
File metadata and controls
242 lines (198 loc) · 9.33 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
#!/usr/bin/env python3
"""
WebSocket client test for validating WebSocket API functionality
This script tests the WebSocket interface of the Image AI API and verifies
that it works compatibly with the REST API for task monitoring.
"""
import asyncio
import websockets
import json
import uuid
import sys
import argparse
import logging
from typing import Optional
# Default server address and the per-message receive timeout (in seconds,
# as passed to asyncio.wait_for).
DEFAULT_HOST = "localhost"
DEFAULT_PORT = 8000
DEFAULT_TIMEOUT = 5.0

# Root logging setup: INFO level with timestamp, logger name and level.
_LOG_FORMAT = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)
logger = logging.getLogger("websocket-test")
async def test_task_websocket(
    task_id: Optional[str] = None,
    host: str = DEFAULT_HOST,
    port: int = DEFAULT_PORT,
    timeout: float = DEFAULT_TIMEOUT,
):
    """Monitor one task over the ``/ws/task/{task_id}`` WebSocket endpoint.

    Each received message is logged via ``process_message``. Monitoring stops
    in any of these cases:

    - ``should_end_monitoring`` says the task is done or the server reported
      an error;
    - the connection closes;
    - no message arrives within ``timeout`` seconds.

    Malformed frames (invalid JSON or a non-object payload) are logged and
    skipped rather than aborting the session.

    Args:
        task_id: The ID of the task to monitor (or None for a built-in sample ID)
        host: The hostname of the server
        port: The port number of the server
        timeout: Seconds to wait for each message before giving up
    """
    # Fall back to a fixed sample ID so the script can run without arguments.
    task_id = task_id or "183dbcb5-b583-46fd-94bf-16f80db98547"
    uri = f"ws://{host}:{port}/ws/task/{task_id}"
    try:
        logger.info(f"🔌 Connecting to {uri}")
        async with websockets.connect(uri) as websocket:
            logger.info("✅ Connected successfully!")
            message_count = 0
            # Keep listening until the task finishes, errors, or goes quiet.
            while True:
                try:
                    message = await asyncio.wait_for(websocket.recv(), timeout=timeout)
                except asyncio.TimeoutError:
                    logger.warning("⏰ Timeout waiting for message")
                    break
                message_count += 1
                try:
                    data = json.loads(message)
                except json.JSONDecodeError as e:
                    # One malformed frame should not end monitoring of the task.
                    logger.error(f"❌ Message {message_count} is not valid JSON: {e}")
                    continue
                if not isinstance(data, dict):
                    # process_message / should_end_monitoring expect an object.
                    logger.error(f"❌ Message {message_count} is not a JSON object: {message}")
                    continue
                process_message(message_count, message, data)
                if should_end_monitoring(data):
                    break
    except websockets.exceptions.ConnectionClosed as e:
        # Base class of ConnectionClosedOK and ConnectionClosedError, so a
        # clean close before the task finishes is reported here too instead of
        # falling through to the generic handler with a traceback.
        logger.error(f"❌ Connection closed: {e}")
    except websockets.exceptions.InvalidStatusCode as e:
        logger.error(f"❌ Invalid status code: {e.status_code}")
    except Exception as e:
        logger.error(f"❌ Error: {e}", exc_info=True)
def process_message(message_count: int, raw_message: str, data: dict) -> None:
    """Log a human-readable summary of one task-status message.

    If the message carries an ``error`` key, only that error is logged.
    Otherwise the following are logged:

    - task ID, status, progress and message text;
    - generation time and output paths, when present;
    - a "finished" line once the status is terminal.

    Args:
        message_count: The sequence number of this message
        raw_message: The raw JSON message as a string (currently unused; kept
            for interface compatibility)
        data: The parsed JSON data as a dictionary
    """
    logger.info(f"\n📨 Message {message_count}:")
    if 'error' in data:
        logger.error(f" ❌ Error: {data['error']}")
        return
    try:
        logger.info(f" Task ID: {data.get('task_id')}")
        # Status may arrive as a plain string or as {"value": ...}; unwrap it
        # once here so the rest of the function only deals with the bare value.
        status = data.get('status')
        if isinstance(status, dict) and 'value' in status:
            status = status['value']
        logger.info(f" Status: {status}")
        logger.info(f" Progress: {data.get('progress')}%")
        logger.info(f" Message: {data.get('message')}")
        generation_time = data.get('generation_time')
        if generation_time is not None:
            logger.info(f" Generation Time: {generation_time:.2f}s")
        paths = data.get('output_paths')
        if paths:
            logger.info(f" Output: {len(paths)} image(s)")
            for path in paths:
                logger.info(f" - {path}")
        # status was already unwrapped above, so a single string check covers
        # both wire formats.
        if isinstance(status, str) and status in ('completed', 'failed', 'cancelled'):
            logger.info(f" 🏁 Task finished with status: {status}")
    except Exception as e:
        # e.g. a non-numeric generation_time rejected by the :.2f format spec
        logger.error(f" Error parsing data: {e}")
def should_end_monitoring(data: dict) -> bool:
    """Decide whether task monitoring is over.

    Monitoring ends when the payload contains an ``error`` key, or when its
    status is terminal (``completed``, ``failed`` or ``cancelled``). The status
    may be given either as a bare string or as an object carrying the status
    under ``value``.

    Args:
        data: The parsed JSON payload received from the WebSocket
    Returns:
        True to stop monitoring, False to keep listening
    """
    finished_states = ('completed', 'failed', 'cancelled')
    if 'error' in data:
        return True
    state = data.get('status')
    if isinstance(state, dict):
        state = state.get('value')
    elif not isinstance(state, str):
        # Missing or unrecognised status shape: keep monitoring.
        return False
    return state in finished_states
async def submit_task_and_monitor(
    host: str = DEFAULT_HOST,
    port: int = DEFAULT_PORT,
    task_data: Optional[dict] = None,
):
    """Submit a new task via the REST API, then monitor it via WebSocket.

    The task is POSTed as JSON to ``/generate``. The returned ``task_id`` is
    then handed to ``test_task_websocket``. Any 2xx response counts as a
    successful submission.

    Args:
        host: The hostname of the server
        port: The port number of the server
        task_data: JSON body to POST; when None, a built-in sample
            Stable Diffusion request is used
    """
    # aiohttp is only needed for this mode, so import it lazily and fail with
    # a clear message instead of an uncaught ImportError.
    try:
        import aiohttp
    except ImportError:
        logger.error("❌ aiohttp is required to submit tasks (pip install aiohttp)")
        return

    if task_data is None:
        task_data = {
            "prompt": "A beautiful cyberpunk cityscape with neon lights",
            "model": "stable-diffusion-v1.5",
            "width": 512,
            "height": 512,
            "num_inference_steps": 20,
            "guidance_scale": 7.5,
            "num_images_per_prompt": 1,
            "scheduler": "euler_discrete",
        }
    api_url = f"http://{host}:{port}/generate"
    try:
        logger.info("📤 Submitting task via REST API...")
        async with aiohttp.ClientSession() as session:
            async with session.post(api_url, json=task_data) as response:
                # Task-creation endpoints commonly answer 201/202, not just 200.
                if not 200 <= response.status < 300:
                    logger.error(
                        f"❌ Failed to submit task (HTTP {response.status}): "
                        f"{await response.text()}"
                    )
                    return
                result = await response.json()

        task_id = result.get("task_id") if isinstance(result, dict) else None
        if not task_id:
            logger.error(f"❌ Submission response has no task_id: {result}")
            return
        logger.info(f"✅ Task submitted successfully with ID: {task_id}")

        # The HTTP session is closed by now; monitoring only needs the WebSocket.
        logger.info("🔄 Now monitoring task via WebSocket...")
        await test_task_websocket(task_id, host, port)
    except Exception as e:
        logger.error(f"❌ Error: {e}", exc_info=True)
async def test_performance_websocket(host: str = DEFAULT_HOST, port: int = DEFAULT_PORT, duration: int = 10):
    """Sample the ``/ws/performance`` WebSocket endpoint and log its metrics.

    Up to ``duration`` messages are received. Each one is expected to carry:

    - ``system.cpu_percent`` and ``system.memory_percent``;
    - ``application.queue_size`` and ``application.active_tasks``.

    Messages that are not valid JSON or lack those fields are logged and
    skipped.

    Note: ``duration`` is a message count, not a wall-clock limit. It matches
    seconds only if the server pushes one update per second (assumption — not
    established by this client).

    Args:
        host: The hostname of the server
        port: The port number of the server
        duration: Number of performance messages to receive
    """
    uri = f"ws://{host}:{port}/ws/performance"
    try:
        logger.info(f"🔌 Connecting to performance WebSocket at {uri}")
        async with websockets.connect(uri) as websocket:
            logger.info("✅ Connected successfully!")
            for i in range(1, duration + 1):
                try:
                    message = await asyncio.wait_for(websocket.recv(), timeout=DEFAULT_TIMEOUT)
                except asyncio.TimeoutError:
                    logger.warning("⏰ Timeout waiting for performance data")
                    break
                try:
                    data = json.loads(message)
                    system = data['system']
                    application = data['application']
                    cpu = system['cpu_percent']
                    memory = system['memory_percent']
                    queue_size = application['queue_size']
                    active_tasks = application['active_tasks']
                except (json.JSONDecodeError, KeyError, TypeError) as e:
                    # One malformed sample should not end the whole run.
                    logger.error(f"❌ Unexpected performance message ({i}/{duration}): {e!r}")
                    continue
                logger.info(f"\n📊 Performance data ({i}/{duration}):")
                logger.info(f" CPU: {cpu}%")
                logger.info(f" Memory: {memory}%")
                logger.info(f" Queue Size: {queue_size}")
                logger.info(f" Active Tasks: {active_tasks}")
    except websockets.exceptions.ConnectionClosed as e:
        logger.error(f"❌ Performance connection closed: {e}")
    except Exception as e:
        logger.error(f"❌ Error in performance monitor: {e}", exc_info=True)
def main():
    """Command-line entry point: read options, then run one WebSocket test.

    ``--performance`` takes precedence over ``--submit``. With neither flag,
    a single task is monitored (``--task-id``, or the built-in sample ID).
    """
    cli = argparse.ArgumentParser(description='Test WebSocket connection for AI Image Generation')
    cli.add_argument('--host', default=DEFAULT_HOST, help='Server hostname')
    cli.add_argument('--port', type=int, default=DEFAULT_PORT, help='Server port')
    cli.add_argument('--task-id', help='Specific task ID to monitor')
    cli.add_argument('--submit', action='store_true', help='Submit a new task and monitor')
    cli.add_argument('--performance', action='store_true', help='Monitor system performance')
    cli.add_argument('--duration', type=int, default=10,
                     help='Duration for performance monitoring in seconds')
    opts = cli.parse_args()

    # Build exactly one coroutine, then run it on a fresh event loop.
    if opts.performance:
        job = test_performance_websocket(opts.host, opts.port, opts.duration)
    elif opts.submit:
        job = submit_task_and_monitor(opts.host, opts.port)
    else:
        job = test_task_websocket(opts.task_id, opts.host, opts.port)
    asyncio.run(job)


if __name__ == "__main__":
    main()