forked from nbovee/tracr
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathserver.py
More file actions
669 lines (560 loc) · 24.4 KB
/
server.py
File metadata and controls
669 lines (560 loc) · 24.4 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
#!/usr/bin/env python
"""
Server-side implementation of the split computing architecture.

The server runs in either networked mode (handling connections from clients)
or local mode (running experiments locally without network communication).
"""
import argparse
import logging
import pickle
import socket
import sys
import time
from contextlib import contextmanager
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, Generator, Optional, Sequence, Tuple

import torch
# Make the directory containing this script importable so that the ``src``
# package imported below resolves regardless of the working directory.
project_root = Path(__file__).resolve().parent
_root_entry = str(project_root)
if _root_entry not in sys.path:
    sys.path.append(_root_entry)
del _root_entry
from src.api import ( # noqa: E402
DataCompression,
DeviceManager,
ExperimentManager,
DeviceType,
start_logging_server,
shutdown_logging_server,
DataCompression, # noqa: F811
read_yaml_file,
)
from src.api.network.protocols import ( # noqa: E402
LENGTH_PREFIX_SIZE,
ACK_MESSAGE,
SERVER_COMPRESSION_SETTINGS,
SERVER_LISTEN_TIMEOUT,
SOCKET_TIMEOUT,
DEFAULT_PORT,
)
# Logging configuration handed to the logging server started below.
DEFAULT_CONFIG: Dict[str, Any] = {
    "logging": {
        "log_file": "logs/server.log",
        "log_level": "INFO",
    },
}

# Importing this module starts the logging server as a side effect; it is
# shut down again in Server.cleanup().
logging_server = start_logging_server(
    device=DeviceType.SERVER,
    config=DEFAULT_CONFIG,
)
logger = logging.getLogger("split_computing_logger")
def get_device(requested_device: Optional[str] = "cuda") -> str:
    """
    Resolve which torch device to use for a requested device name.

    Resolution rules:
    - ``"cpu"`` is always honoured.
    - ``"cuda"`` / ``"mps"`` are used directly when that backend is available.
    - ``"cuda"``, ``"mps"`` and the generic alias ``"gpu"`` otherwise fall back
      to any available GPU backend (CUDA first, then MPS), and finally to CPU.
    - Any other value falls back to CPU with a warning.

    Args:
        requested_device: Device name, typically from configuration. Matching
            is case-insensitive and ignores surrounding whitespace; ``None``
            is treated as the default ``"cuda"``.

    Returns:
        One of ``"cpu"``, ``"cuda"`` or ``"mps"``.
    """
    requested = "cuda" if requested_device is None else str(requested_device)
    requested = requested.strip().lower()

    if requested == "cpu":
        logger.info("CPU device explicitly requested")
        return "cpu"

    cuda_available = torch.cuda.is_available()
    # hasattr guard: not every torch build exposes torch.backends.mps.
    mps_available = (
        hasattr(torch.backends, "mps") and torch.backends.mps.is_available()
    )

    if requested == "cuda" and cuda_available:
        logger.info("CUDA is available and will be used")
        return "cuda"

    if requested == "mps" and mps_available:
        logger.info("MPS (Apple Silicon GPU) is available and will be used")
        return "mps"

    if requested in ("cuda", "gpu", "mps"):
        # "gpu" is a generic alias: picking a concrete backend for it is not
        # a fallback, so it gets a non-misleading message.
        if cuda_available:
            if requested == "gpu":
                logger.info("GPU requested, using CUDA")
            else:
                logger.info(
                    f"{requested.upper()} requested but not available, using CUDA instead"
                )
            return "cuda"
        if mps_available:
            if requested == "gpu":
                logger.info("GPU requested, using MPS (Apple Silicon GPU)")
            else:
                logger.info(
                    f"{requested.upper()} requested but not available, using MPS (Apple Silicon GPU) instead"
                )
            return "mps"
        logger.warning(
            f"{requested.upper()} requested but no GPU available, falling back to CPU"
        )
        return "cpu"

    # For any other requested device, fall back to CPU
    logger.warning(
        f"Requested device '{requested}' not recognized, falling back to CPU"
    )
    return "cpu"
@dataclass
class ServerMetrics:
    """Running request count and processing-time statistics for the server."""

    # Number of requests recorded through update().
    total_requests: int = 0
    # Sum of every processing time passed to update().
    total_processing_time: float = 0.0
    # Mean processing time per request; 0.0 until the first update().
    avg_processing_time: float = 0.0

    def update(self, processing_time: float) -> None:
        """Record one request's processing time and refresh the running mean."""
        self.total_requests = self.total_requests + 1
        self.total_processing_time = self.total_processing_time + processing_time
        self.avg_processing_time = self.total_processing_time / self.total_requests
class Server:
    """
    Handles server operations for managing connections and processing data.

    This class implements both networked and local modes:
    - Networked mode: listens for client connections and processes data sent by clients
    - Local mode: runs experiments locally using the provided configuration
    """

    def __init__(
        self, local_mode: bool = False, config_path: Optional[str] = None
    ) -> None:
        """
        Initialize the Server with specified mode and configuration.

        Args:
            local_mode: If True, ``start()`` runs a local experiment; otherwise
                it runs the networked server.
            config_path: Path to a YAML configuration file. When given it is
                loaded immediately and its ``default.device`` entry is
                resolved with ``get_device``.
        """
        self.device_manager = DeviceManager()
        self.experiment_manager: Optional[ExperimentManager] = None
        self.server_socket: Optional[socket.socket] = None
        self.local_mode = local_mode
        self.config_path = config_path
        # Parsed configuration; stays None when no config_path is given.
        self.config: Optional[Dict[str, Any]] = None
        self.metrics = ServerMetrics()
        self.compress_data: Optional[DataCompression] = None
        # cleanup() can be reached more than once (the networked server's
        # ``finally`` and the script entry point both call it).
        self._cleaned_up = False

        self._load_config_and_setup_device()

        # Setup compression if in networked mode
        if not local_mode:
            self._setup_compression()
            logger.debug("Server initialized in network mode")
        else:
            logger.debug("Server initialized in local mode")

    def _load_config_and_setup_device(self) -> None:
        """
        Load the YAML configuration (if a path was given) and resolve its device.

        The requested ``default.device`` (default ``"cuda"``) is replaced in
        ``self.config`` by the device ``get_device`` reports as usable. A
        missing or empty ``default`` section is created instead of raising.
        """
        if not self.config_path:
            return
        self.config = read_yaml_file(self.config_path) or {}
        defaults = self.config.get("default") or {}
        self.config["default"] = defaults
        defaults["device"] = get_device(defaults.get("device", "cuda"))

    def _setup_compression(self) -> None:
        """Initialize compression with the server's minimal default settings."""
        self.compress_data = DataCompression(SERVER_COMPRESSION_SETTINGS)
        logger.debug("Initialized compression with minimal settings")

    def start(self) -> None:
        """Start the server in either networked or local mode."""
        if self.local_mode:
            self._run_local_experiment()
        else:
            self._run_networked_server()

    def _run_local_experiment(self) -> None:
        """Run an experiment locally; errors are logged rather than raised."""
        if not self.config_path:
            logger.error("Config path required for local mode")
            return
        try:
            logger.info("Starting local experiment...")
            self._setup_and_run_local_experiment()
            logger.info("Local experiment completed successfully")
        except Exception as e:
            logger.error(f"Error running local experiment: {e}", exc_info=True)

    def _setup_and_run_local_experiment(self) -> None:
        """
        Set up and run a local experiment based on configuration.

        Builds the experiment, loads the configured dataset through
        ``DatasetRegistry``, wraps it in a ``torch.utils.data.DataLoader`` and
        runs the experiment over it.

        Raises:
            ValueError: If the ``dataset`` section has no ``name``.
            Exception: Errors from dataset loading or the experiment propagate.
        """
        from src.experiment_design.datasets.core.loaders import DatasetRegistry
        import torch.utils.data

        # Use the config loaded in __init__ so the device resolved by
        # get_device() takes effect; re-reading the file would discard it.
        config = self.config
        if config is None:
            config = read_yaml_file(self.config_path)

        self.experiment_manager = ExperimentManager(config, force_local=True)
        experiment = self.experiment_manager.setup_experiment()

        dataset_config = config.get("dataset", {})
        dataloader_config = config.get("dataloader", {})

        # Get the appropriate collate function if specified
        collate_fn = self._get_collate_function(dataloader_config)

        dataset_name = dataset_config.get("name")
        if not dataset_name:
            # Raise (instead of returning) so the caller does not log success.
            raise ValueError(
                "Dataset name not specified in config (required 'name' field)"
            )

        # Copy so the caller's config is not mutated by the transform merge.
        complete_config = dataset_config.copy()
        if "transform" not in complete_config and "transform" in dataloader_config:
            complete_config["transform"] = dataloader_config.get("transform")

        try:
            if DatasetRegistry.get_metadata(dataset_name) is None:
                logger.info(f"Registering dataset '{dataset_name}'")
                DatasetRegistry.register_dataset(dataset_name)
            dataset = DatasetRegistry.load(complete_config)
            logger.info(f"Loaded dataset '{dataset_name}' successfully")
        except Exception as e:
            logger.error(f"Failed to load dataset '{dataset_name}': {e}")
            raise

        data_loader = torch.utils.data.DataLoader(
            dataset,
            # NOTE(review): an omitted batch_size yields None, which DataLoader
            # treats as "disable automatic batching" - confirm this is intended.
            batch_size=dataloader_config.get("batch_size"),
            shuffle=dataloader_config.get("shuffle"),
            # DataLoader rejects num_workers=None, so default to its own 0.
            num_workers=dataloader_config.get("num_workers", 0),
            collate_fn=collate_fn,
        )

        experiment.data_loader = data_loader
        experiment.run()

    def _get_collate_function(self, dataloader_config: Dict[str, Any]) -> Optional[Any]:
        """
        Get the collate function specified in the configuration.

        Collate functions customize how individual data samples are combined
        into batches. Returns None (default collation) when no function is
        configured, the registry cannot be imported, or the name is unknown.
        """
        if not dataloader_config.get("collate_fn"):
            return None

        try:
            from src.experiment_design.datasets.core.collate_fns import CollateRegistry

            collate_fn_name = dataloader_config["collate_fn"]
            collate_fn = CollateRegistry.get(collate_fn_name)
            if not collate_fn:
                logger.warning(
                    f"Collate function '{collate_fn_name}' not found in registry. "
                    "Using default collation."
                )
                return None
            logger.debug(f"Using registered collate function: {collate_fn_name}")
            return collate_fn
        except ImportError as e:
            logger.warning(
                f"Failed to import collate functions: {e}. Using default collation."
            )
            return None
        except KeyError:
            logger.warning(
                f"Collate function '{dataloader_config['collate_fn']}' not found. "
                "Using default collation."
            )
            return None

    def _run_networked_server(self) -> None:
        """Run server in networked mode; always calls cleanup() on exit."""
        server_device = self.device_manager.get_device_by_type("SERVER")
        if not server_device:
            logger.error("No SERVER device configured. Cannot start server.")
            return

        if not server_device.is_reachable():
            logger.error("SERVER device is not reachable. Check network connection.")
            return

        # Use experiment port for network communication
        port = server_device.get_port()
        if port is None:
            logger.info(
                f"No port configured for SERVER device, using DEFAULT_PORT={DEFAULT_PORT}"
            )
            port = DEFAULT_PORT

        logger.info(f"Starting networked server on port {port}...")

        try:
            self._setup_socket(port)
            self._accept_connections()
        except KeyboardInterrupt:
            logger.info("Server shutdown requested...")
        except Exception as e:
            logger.error(f"Server error: {e}", exc_info=True)
        finally:
            self.cleanup()

    def _accept_connections(self) -> None:
        """
        Accept and handle client connections in a continuous loop.

        Clients are served one at a time (handle_connection runs inline).
        The listening socket's timeout lets the loop wake up periodically so
        a KeyboardInterrupt can be delivered.
        """
        while True:
            try:
                conn, addr = self.server_socket.accept()
                # Set timeout on client socket for data operations
                conn.settimeout(SOCKET_TIMEOUT)
                logger.info(f"Connected by {addr}")
                self.handle_connection(conn)
            except socket.timeout:
                continue
            except ConnectionError as e:
                logger.error(f"Connection error: {e}")
                continue

    def _setup_socket(self, port: int) -> None:
        """
        Create, bind and listen on the server socket.

        The socket allows address reuse (SO_REUSEADDR), has a timeout so the
        accept loop can notice shutdown requests, and listens on all
        interfaces (empty host string). On failure the half-built socket is
        closed, the error is logged, and the exception is re-raised.
        """
        sock: Optional[socket.socket] = None
        try:
            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            sock.settimeout(SERVER_LISTEN_TIMEOUT)
            sock.bind(("", port))
            sock.listen()
        except Exception as e:
            if sock is not None:
                sock.close()
            logger.error(f"Failed to create server socket: {e}")
            raise
        self.server_socket = sock
        logger.info(f"Server is listening on port {port} (all interfaces)")

    @staticmethod
    def _recv_exact(conn: socket.socket, num_bytes: int) -> bytes:
        """
        Read exactly ``num_bytes`` from ``conn``.

        ``recv`` on a stream socket may return fewer bytes than requested, so
        this loops until enough data arrives. Fewer bytes are returned only
        if the peer closes the connection first.
        """
        buffer = bytearray()
        while len(buffer) < num_bytes:
            chunk = conn.recv(num_bytes - len(buffer))
            if not chunk:
                break
            buffer.extend(chunk)
        return bytes(buffer)

    def _receive_config(self, conn: socket.socket) -> dict:
        """
        Receive and parse configuration from client.

        Implements a length-prefixed protocol:
        1. LENGTH_PREFIX_SIZE bytes (4 per the protocol) give the message
           length as a big-endian unsigned integer.
        2. The following bytes contain the pickled configuration.

        Returns:
            The deserialized configuration, or an empty dict on any failure.
        """
        try:
            config_length_bytes = self._recv_exact(conn, LENGTH_PREFIX_SIZE)
            if len(config_length_bytes) != LENGTH_PREFIX_SIZE:
                logger.error("Failed to receive config length prefix")
                return {}

            config_length = int.from_bytes(config_length_bytes, "big")
            logger.debug(f"Expecting config data of length {config_length} bytes")

            if not self.compress_data:
                logger.error("Compression not initialized")
                return {}

            # Receive the raw config data (no compression for config)
            config_data = self.compress_data.receive_full_message(
                conn=conn, expected_length=config_length
            )
            if not config_data:
                logger.error("Failed to receive config data")
                return {}

            # SECURITY: pickle.loads can execute arbitrary code embedded in the
            # payload. Only safe if every client able to reach this port is
            # trusted; consider a data-only format (e.g. JSON) otherwise.
            try:
                config = pickle.loads(config_data)
            except Exception as e:
                logger.error(f"Failed to deserialize config: {e}")
                return {}
            logger.debug("Successfully received and parsed configuration")
            return config
        except Exception as e:
            logger.error(f"Error receiving config: {e}")
            return {}

    def _process_data(
        self,
        experiment: Any,
        output: torch.Tensor,
        original_size: Tuple[int, int],
        split_layer_index: int,
    ) -> Tuple[Any, float]:
        """
        Continue model execution from the split point and time it.

        Args:
            experiment: The experiment object that will process the data
            output: The tensor output from the client
            original_size: Original size information
            split_layer_index: The index of the split layer

        Returns:
            Tuple of (processed_result, processing_time in seconds)
        """
        # perf_counter is monotonic, so the duration is immune to clock changes.
        start = time.perf_counter()
        processed_result = experiment.process_data(
            {"input": (output, original_size), "split_layer": split_layer_index}
        )
        return processed_result, time.perf_counter() - start

    @contextmanager
    def _safe_connection(self, conn: socket.socket) -> Generator[None, None, None]:
        """
        Log any Exception raised in the body and always close ``conn``.

        Exceptions derived from Exception are logged and not propagated.
        """
        try:
            yield
        except Exception as e:
            logger.error(f"Error handling connection: {e}", exc_info=True)
        finally:
            try:
                conn.close()
            except Exception as e:
                logger.debug(f"Error closing connection: {e}")

    def handle_connection(self, conn: socket.socket) -> None:
        """
        Handle an individual client connection for split computing.

        Protocol:
        1. Receive experiment configuration from client
        2. Initialize experiment based on received configuration
        3. Send acknowledgment to client
        4. Process tensor requests until the client disconnects
           (see ``_serve_requests``)

        Errors are logged and the connection is closed on exit.
        """
        with self._safe_connection(conn):
            config = self._receive_config(conn)
            if not config:
                logger.error("Failed to receive valid configuration from client")
                return

            # Update compression settings based on received config
            self._update_compression(config)

            try:
                self.experiment_manager = ExperimentManager(config)
                experiment = self.experiment_manager.setup_experiment()
                experiment.model.eval()
                logger.info("Experiment initialized successfully with received config")
            except Exception as e:
                logger.error(f"Failed to initialize experiment: {e}")
                return

            # Send acknowledgment to the client - must be exactly b"OK"
            conn.sendall(ACK_MESSAGE)
            logger.debug("Sent 'OK' acknowledgment to client")

            self._serve_requests(conn, experiment)

    def _serve_requests(self, conn: socket.socket, experiment: Any) -> None:
        """Process tensor requests on ``conn`` until disconnect or error."""
        # Nothing in this loop reassigns self.compress_data, so check it once.
        compressor = self.compress_data
        if not compressor:
            logger.error("Compression not initialized")
            return

        # Header: split layer index then payload length, each
        # LENGTH_PREFIX_SIZE bytes, big-endian.
        header_size = LENGTH_PREFIX_SIZE * 2
        no_grad_context = torch.no_grad()

        while True:
            try:
                header = self._recv_exact(conn, header_size)
                if len(header) != header_size:
                    logger.info("Client disconnected or sent invalid header")
                    break

                split_layer_index = int.from_bytes(header[:LENGTH_PREFIX_SIZE], "big")
                expected_length = int.from_bytes(header[LENGTH_PREFIX_SIZE:], "big")
                logger.debug(
                    f"Received header: split_layer={split_layer_index}, data_length={expected_length}"
                )

                compressed_data = compressor.receive_full_message(
                    conn=conn, expected_length=expected_length
                )
                if not compressed_data:
                    logger.warning("Failed to receive compressed data from client")
                    break
                logger.debug(f"Received {len(compressed_data)} bytes of compressed data")

                with no_grad_context:
                    output, original_size = compressor.decompress_data(
                        compressed_data=compressed_data
                    )
                    processed_result, processing_time = self._process_data(
                        experiment=experiment,
                        output=output,
                        original_size=original_size,
                        split_layer_index=split_layer_index,
                    )
                    self.metrics.update(processing_time)
                    logger.debug(f"Processed data in {processing_time:.4f}s")

                    compressed_result, result_size = compressor.compress_data(
                        processed_result
                    )
                    self._send_result(
                        conn, result_size, processing_time, compressed_result
                    )
                    logger.debug(f"Sent result of size {result_size} bytes back to client")
            except Exception as e:
                logger.error(f"Error processing client data: {e}", exc_info=True)
                break

    def _send_result(
        self,
        conn: socket.socket,
        result_size: int,
        processing_time: float,
        compressed_result: bytes,
    ) -> None:
        """
        Send the processed result back to the client as one framed message.

        Frame layout:
        1. LENGTH_PREFIX_SIZE-byte big-endian result size
        2. LENGTH_PREFIX_SIZE-byte ASCII processing time: fixed-point,
           truncated to fit, space-padded if shorter (e.g. ``0.01``)
        3. Variable-length compressed result data

        Raises:
            Exception: Any encoding/send error is logged and re-raised.
        """
        try:
            size_bytes = result_size.to_bytes(LENGTH_PREFIX_SIZE, "big")
            # Fixed-point formatting avoids scientific notation: str(1.2e-05)
            # truncated to 4 characters is "1.2e", which is not a number.
            time_str = f"{processing_time:.6f}"[:LENGTH_PREFIX_SIZE]
            time_bytes = time_str.ljust(LENGTH_PREFIX_SIZE).encode()
            # A single sendall avoids several small back-to-back writes.
            conn.sendall(b"".join((size_bytes, time_bytes, compressed_result)))
        except Exception as e:
            logger.error(f"Error sending result: {e}")
            raise

    def _update_compression(self, config: dict) -> None:
        """
        Replace the compressor with one built from ``config["compression"]``.

        When the key is absent the current (minimal) settings are kept.
        """
        if "compression" in config:
            logger.debug(f"Updating compression settings: {config['compression']}")
            self.compress_data = DataCompression(config["compression"])
        else:
            logger.warning(
                "No compression settings in config, keeping minimal settings"
            )

    def cleanup(self) -> None:
        """
        Close the server socket, log final metrics and stop the logging server.

        Safe to call more than once; only the first call does any work.
        """
        if getattr(self, "_cleaned_up", False):
            return
        self._cleaned_up = True

        logger.info("Starting server cleanup...")
        if self.server_socket is not None:
            sock, self.server_socket = self.server_socket, None
            try:
                sock.shutdown(socket.SHUT_RDWR)
            except OSError as e:
                # shutdown() can fail on a socket that is not connected;
                # close() below must still run.
                logger.debug(f"Socket shutdown skipped: {e}")
            try:
                sock.close()
                logger.info("Server socket cleaned up")
            except Exception as e:
                logger.error(f"Error during socket cleanup: {e}")

        # Log metrics before the logging server goes away.
        if self.metrics.total_requests > 0:
            logger.info(
                f"Final metrics: {self.metrics.total_requests} requests processed, "
                f"average processing time: {self.metrics.avg_processing_time:.4f}s"
            )

        if logging_server:
            shutdown_logging_server(logging_server)
def parse_arguments(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """
    Parse command line arguments.

    Args:
        argv: Arguments to parse, excluding the program name. ``None`` (the
            default) makes argparse read ``sys.argv[1:]``.

    Returns:
        Namespace with ``local`` (bool) and ``config`` (str or None).

    Raises:
        SystemExit: Via ``parser.error`` (exit status 2) when ``--local`` is
            given without ``--config``.
    """
    parser = argparse.ArgumentParser(description="Run server for split computing")
    parser.add_argument(
        "-l",
        "--local",
        action="store_true",
        help="Run experiment locally instead of as a network server",
    )
    parser.add_argument(
        "-c",
        "--config",
        type=str,
        help="Path to configuration file (required for local mode)",
        required=False,
    )
    args = parser.parse_args(argv)

    # Local mode has no client to supply a config, so one must be given here.
    if args.local and not args.config:
        parser.error("--config is required when running in local mode")

    return args
if __name__ == "__main__":
    args = parse_arguments()
    server: Optional[Server] = None
    exit_code = 0
    try:
        # Construct inside the try so config/device errors are logged too.
        server = Server(local_mode=args.local, config_path=args.config)
        server.start()
    except KeyboardInterrupt:
        logger.info("Shutting down server due to keyboard interrupt...")
    except Exception as e:
        logger.error(f"Server crashed with error: {e}", exc_info=True)
        # Report the crash to the shell instead of exiting with status 0.
        exit_code = 1
    finally:
        if server is not None:
            server.cleanup()
        elif logging_server:
            # No Server exists to run cleanup(), so release the module-level
            # logging server directly.
            shutdown_logging_server(logging_server)
    if exit_code:
        sys.exit(exit_code)