Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 33 additions & 24 deletions api/internal/auth/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,36 +146,45 @@ func Middleware(jwtManager *JWTManager, userDB *db.UserDB) gin.HandlerFunc {
// Check if this is a WebSocket upgrade request
isWebSocket := c.GetHeader("Upgrade") == "websocket" && c.GetHeader("Connection") == "Upgrade"

// Extract token from Authorization header
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
// For WebSocket, abort without writing response (let upgrader handle it)
if isWebSocket {
c.AbortWithStatus(http.StatusUnauthorized)
var tokenString string

// For WebSocket connections, try query parameter first (browsers can't send custom headers)
if isWebSocket {
tokenString = c.Query("token")
}

// If no token from query parameter, try Authorization header
if tokenString == "" {
authHeader := c.GetHeader("Authorization")
if authHeader == "" {
// For WebSocket, abort without writing response (let upgrader handle it)
if isWebSocket {
c.AbortWithStatus(http.StatusUnauthorized)
return
}
c.JSON(http.StatusUnauthorized, gin.H{
"error": "Authorization header required",
})
c.Abort()
return
}
c.JSON(http.StatusUnauthorized, gin.H{
"error": "Authorization header required",
})
c.Abort()
return
}

// Check Bearer prefix
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) != 2 || parts[0] != "Bearer" {
if isWebSocket {
c.AbortWithStatus(http.StatusUnauthorized)
// Check Bearer prefix
parts := strings.SplitN(authHeader, " ", 2)
if len(parts) != 2 || parts[0] != "Bearer" {
if isWebSocket {
c.AbortWithStatus(http.StatusUnauthorized)
return
}
c.JSON(http.StatusUnauthorized, gin.H{
"error": "Invalid authorization header format. Use: Bearer <token>",
})
c.Abort()
return
}
c.JSON(http.StatusUnauthorized, gin.H{
"error": "Invalid authorization header format. Use: Bearer <token>",
})
c.Abort()
return
}

tokenString := parts[1]
tokenString = parts[1]
}

// Validate token
claims, err := jwtManager.ValidateToken(tokenString)
Expand Down
16 changes: 13 additions & 3 deletions ui/src/components/EnterpriseWebSocketProvider.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ export default function EnterpriseWebSocketProvider({
enableNotifications = true,
}: EnterpriseWebSocketProviderProps) {
const [notifications, setNotifications] = useState<Notification[]>([]);
const [reconnectDismissed, setReconnectDismissed] = useState(false); // Track if reconnect banner was dismissed

const addNotification = useCallback((message: string, severity: Notification['severity']) => {
const id = `${Date.now()}-${Math.random()}`;
Expand Down Expand Up @@ -213,12 +214,21 @@ export default function EnterpriseWebSocketProvider({
))}

{/* Connection status indicator (optional) */}
{!isConnected && reconnectAttempts > 0 && (
{!isConnected && reconnectAttempts > 0 && !reconnectDismissed && (
<Snackbar
open={true}
anchorOrigin={{ vertical: 'top', horizontal: 'center' }}
anchorOrigin={{ vertical: 'bottom', horizontal: 'left' }}
onClose={() => setReconnectDismissed(true)}
>
<Alert severity="warning" variant="filled">
<Alert
severity="info"
variant="outlined"
onClose={() => setReconnectDismissed(true)}
sx={{
backgroundColor: 'background.paper',
boxShadow: 1,
}}
>
Reconnecting... (Attempt {reconnectAttempts}/10)
</Alert>
</Snackbar>
Expand Down
24 changes: 16 additions & 8 deletions ui/src/components/WebSocketErrorBoundary.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ interface State {
hasError: boolean;
error: Error | null;
errorInfo: React.ErrorInfo | null;
dismissed: boolean; // Track if user has dismissed the error
}

export default class WebSocketErrorBoundary extends Component<Props, State> {
Expand All @@ -40,14 +41,14 @@ export default class WebSocketErrorBoundary extends Component<Props, State> {
hasError: false,
error: null,
errorInfo: null,
dismissed: false,
};
}

static getDerivedStateFromError(error: Error): State {
static getDerivedStateFromError(error: Error): Partial<State> {
return {
hasError: true,
error,
errorInfo: null,
};
}

Expand All @@ -70,10 +71,17 @@ export default class WebSocketErrorBoundary extends Component<Props, State> {
hasError: false,
error: null,
errorInfo: null,
dismissed: true, // Mark as dismissed
});
};

render() {
// If error was already dismissed, just render children without showing error UI
if (this.state.hasError && this.state.dismissed) {
console.warn('WebSocket error (dismissed):', this.state.error?.message);
return this.props.children;
}

if (this.state.hasError) {
// Use custom fallback if provided
if (this.props.fallback) {
Expand All @@ -96,7 +104,7 @@ export default class WebSocketErrorBoundary extends Component<Props, State> {
<AlertTitle>WebSocket Connection Error</AlertTitle>
<Typography variant="body2" paragraph>
There was an error with the real-time connection. The page will continue to work,
but live updates may be unavailable. You can try refreshing the page or reconnecting.
but live updates may be unavailable.
</Typography>

{this.props.showErrorDetails && this.state.error && (
Expand All @@ -110,18 +118,18 @@ export default class WebSocketErrorBoundary extends Component<Props, State> {
<Box sx={{ mt: 2, display: 'flex', gap: 1 }}>
<Button
variant="contained"
startIcon={<RefreshIcon />}
onClick={() => window.location.reload()}
onClick={this.handleReset}
size="small"
>
Reload Page
Continue Without Live Updates
</Button>
<Button
variant="outlined"
onClick={this.handleReset}
startIcon={<RefreshIcon />}
onClick={() => window.location.reload()}
size="small"
>
Try Again
Reload Page
</Button>
</Box>
</Alert>
Expand Down
14 changes: 8 additions & 6 deletions ui/src/hooks/useEnterpriseWebSocket.ts
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,14 @@ export function useEnterpriseWebSocket(
const getWebSocketUrl = useCallback(() => {
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
const host = window.location.host;
const token = localStorage.getItem('token');

// Include token as query parameter for WebSocket authentication
// Browsers cannot send custom headers in WebSocket connections
if (token) {
return `${protocol}//${host}/api/v1/ws/enterprise?token=${encodeURIComponent(token)}`;
}

return `${protocol}//${host}/api/v1/ws/enterprise`;
}, []);

Expand All @@ -65,12 +73,6 @@ export function useEnterpriseWebSocket(
}

try {
const token = localStorage.getItem('token');
if (!token) {
console.error('No authentication token found');
return;
}

const wsUrl = getWebSocketUrl();
// console.log(`[WebSocket] Connecting to ${wsUrl}`);

Expand Down
18 changes: 12 additions & 6 deletions ui/src/hooks/useWebSocket.ts
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,10 @@ export function useWebSocket({
* Hook for subscribing to session updates via WebSocket
*/
export function useSessionsWebSocket(onUpdate: (sessions: any[]) => void) {
const apiUrl = import.meta.env.VITE_API_URL || 'http://localhost:8000';
const wsUrl = apiUrl.replace(/^http/, 'ws') + '/api/v1/ws/sessions';
// Use window.location to connect through Vite proxy in dev, or directly in production
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
const token = localStorage.getItem('token');
const wsUrl = `${protocol}//${window.location.host}/api/v1/ws/sessions${token ? `?token=${encodeURIComponent(token)}` : ''}`;

return useWebSocket({
url: wsUrl,
Expand All @@ -151,8 +153,10 @@ export function useSessionsWebSocket(onUpdate: (sessions: any[]) => void) {
* Hook for subscribing to cluster metrics via WebSocket
*/
export function useMetricsWebSocket(onUpdate: (metrics: any) => void) {
const apiUrl = import.meta.env.VITE_API_URL || 'http://localhost:8000';
const wsUrl = apiUrl.replace(/^http/, 'ws') + '/api/v1/ws/cluster';
// Use window.location to connect through Vite proxy in dev, or directly in production
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
const token = localStorage.getItem('token');
const wsUrl = `${protocol}//${window.location.host}/api/v1/ws/cluster${token ? `?token=${encodeURIComponent(token)}` : ''}`;

return useWebSocket({
url: wsUrl,
Expand All @@ -170,8 +174,10 @@ export function useMetricsWebSocket(onUpdate: (metrics: any) => void) {
* Hook for subscribing to pod logs via WebSocket
*/
export function useLogsWebSocket(namespace: string, podName: string, onLog: (log: string) => void) {
const apiUrl = import.meta.env.VITE_API_URL || 'http://localhost:8000';
const wsUrl = apiUrl.replace(/^http/, 'ws') + `/api/v1/ws/logs/${namespace}/${podName}`;
// Use window.location to connect through Vite proxy in dev, or directly in production
const protocol = window.location.protocol === 'https:' ? 'wss:' : 'ws:';
const token = localStorage.getItem('token');
const wsUrl = `${protocol}//${window.location.host}/api/v1/ws/logs/${namespace}/${podName}${token ? `?token=${encodeURIComponent(token)}` : ''}`;

return useWebSocket({
url: wsUrl,
Expand Down
1 change: 1 addition & 0 deletions ui/vite.config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ export default defineConfig({
'/api': {
target: 'http://localhost:8000',
changeOrigin: true,
ws: true, // Enable WebSocket proxying
},
'/webhooks': {
target: 'http://localhost:8000',
Expand Down
Loading