Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions project-words.txt
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,12 @@ milli
mocharc
openapitools
pkce
plotly
Plotly
Pyright
quickpick
refreshable
sess
toggleable
tombstoned
toolsai
Expand Down
10 changes: 7 additions & 3 deletions src/jupyter/colab-proxy-web-socket.ts
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@ import {
COLAB_RUNTIME_PROXY_TOKEN_HEADER,
} from '../colab/headers';
import { warnOnDriveMount } from './drive-mount-warning';
import { injectPlotlyConfig } from './plotly-config';

/**
* Returns a class which extends {@link WebSocket}, adds Colab's custom headers,
* and intercepts {@link WebSocket.send} to warn users when on `drive.mount`
* execution.
* intercepts {@link WebSocket.send} to warn users when on `drive.mount`
* execution, and auto-configures Plotly renderer for Colab compatibility.
*/
export function colabProxyWebSocket(
vs: typeof vscode,
Expand Down Expand Up @@ -67,11 +68,14 @@ export function colabProxyWebSocket(
) {
warnOnDriveMount(vs, data);

// Auto-configure Plotly renderer for Colab compatibility
const modifiedData = injectPlotlyConfig(data);

if (options === undefined || typeof options === 'function') {
cb = options;
options = {};
}
super.send(data, options, cb);
super.send(modifiedData, options, cb);
}
};
}
121 changes: 121 additions & 0 deletions src/jupyter/colab-proxy-web-socket.unit.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,19 @@ import { expect } from 'chai';
import WebSocket from 'ws';
import { newVsCodeStub, VsCodeStub } from '../test/helpers/vscode';
import { colabProxyWebSocket } from './colab-proxy-web-socket';
import { resetConfiguredSessions } from './plotly-config';

describe('colabProxyWebSocket', () => {
const testToken = 'test-token';
let vsCodeStub: VsCodeStub;

beforeEach(() => {
vsCodeStub = newVsCodeStub();
resetConfiguredSessions();
});

afterEach(() => {
resetConfiguredSessions();
});

const tests = [
Expand Down Expand Up @@ -69,4 +75,119 @@ describe('colabProxyWebSocket', () => {
'X-Colab-Client-Agent': 'vscode',
});
}

/**
* Type for parsed Jupyter kernel messages in tests.
*/
interface ParsedMessage {
header: {
msg_type: string;
session?: string;
};
content?: {
code?: string;
};
}

describe('Plotly config injection', () => {
it('injects Plotly config on first execute_request', () => {
const sentData: string[] = [];
class MockWebSocket extends WebSocket {
constructor(_address: string | URL | null) {
super(null);
}
override send(data: string) {
sentData.push(data);
}
}

const ColabWebSocket = colabProxyWebSocket(
vsCodeStub.asVsCode(),
testToken,
MockWebSocket,
);
const ws = new ColabWebSocket('ws://example.com/socket');

const executeRequest = JSON.stringify({
header: { msg_type: 'execute_request', session: 'test-session' },
content: { code: 'print("hello")' },
});

ws.send(executeRequest, {});

expect(sentData.length).to.equal(1);
const parsed = JSON.parse(sentData[0]) as ParsedMessage;
expect(parsed.content?.code).to.include('plotly.io');
expect(parsed.content?.code).to.include('print("hello")');
});

it('does not inject Plotly config on subsequent requests for same session', () => {
const sentData: string[] = [];
class MockWebSocket extends WebSocket {
constructor(_address: string | URL | null) {
super(null);
}
override send(data: string) {
sentData.push(data);
}
}

const ColabWebSocket = colabProxyWebSocket(
vsCodeStub.asVsCode(),
testToken,
MockWebSocket,
);
const ws = new ColabWebSocket('ws://example.com/socket');

const executeRequest1 = JSON.stringify({
header: { msg_type: 'execute_request', session: 'test-session-2' },
content: { code: 'x = 1' },
});
const executeRequest2 = JSON.stringify({
header: { msg_type: 'execute_request', session: 'test-session-2' },
content: { code: 'y = 2' },
});

ws.send(executeRequest1, {});
ws.send(executeRequest2, {});

expect(sentData.length).to.equal(2);
// First request should have Plotly config
expect(
(JSON.parse(sentData[0]) as ParsedMessage).content?.code,
).to.include('plotly.io');
// Second request should NOT have Plotly config
expect((JSON.parse(sentData[1]) as ParsedMessage).content?.code).to.equal(
'y = 2',
);
});

it('does not modify non-execute_request messages', () => {
const sentData: string[] = [];
class MockWebSocket extends WebSocket {
constructor(_address: string | URL | null) {
super(null);
}
override send(data: string) {
sentData.push(data);
}
}

const ColabWebSocket = colabProxyWebSocket(
vsCodeStub.asVsCode(),
testToken,
MockWebSocket,
);
const ws = new ColabWebSocket('ws://example.com/socket');

const kernelInfoRequest = JSON.stringify({
header: { msg_type: 'kernel_info_request', session: 'test-session-3' },
});

ws.send(kernelInfoRequest, {});

expect(sentData.length).to.equal(1);
expect(sentData[0]).to.equal(kernelInfoRequest);
});
});
});
131 changes: 131 additions & 0 deletions src/jupyter/plotly-config.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/**
* @license
* Copyright 2025 Google LLC
* SPDX-License-Identifier: Apache-2.0
*/

import { z } from 'zod';

/**
* Python code to configure Plotly to use the 'plotly_mimetype' renderer.
* This makes Plotly visualizations work correctly in VS Code when connected
* to a Colab runtime. The code is wrapped in a try/except to gracefully
* handle cases where Plotly is not installed.
*/
const PLOTLY_CONFIG_CODE = `
try:
import plotly.io as pio
if pio.renderers.default != 'plotly_mimetype':
pio.renderers.default = 'plotly_mimetype'
except ImportError:
pass
`.trim();

/**
* Tracks kernel sessions that have already been configured.
* Uses session ID as key to ensure config runs once per kernel session.
*/
const configuredSessions = new Set<string>();

/**
* Clears the set of configured sessions.
* Useful for testing and when kernel sessions are restarted.
*/
export function resetConfiguredSessions(): void {
configuredSessions.clear();
}

/**
* Returns the number of currently configured sessions.
* Useful for testing.
*/
export function getConfiguredSessionCount(): number {
return configuredSessions.size;
}

/**
* Interface representing a Jupyter execute request message.
*/
interface JupyterExecuteRequestMessage {
header: {
msg_type: 'execute_request';
session: string;
};
content: {
code: string;
};
}

/**
* Zod schema for validating Jupyter execute request messages.
*/
const ExecuteRequestSchema = z.object({
header: z.object({
msg_type: z.literal('execute_request'),
session: z.string(),
}),
content: z.object({
code: z.string(),
}),
});

/**
* Type guard to check if a message is a Jupyter execute request.
*/
function isExecuteRequest(
message: unknown,
): message is JupyterExecuteRequestMessage {
return ExecuteRequestSchema.safeParse(message).success;
}

/**
* Injects Plotly configuration code into the first execute request for each
* kernel session. This ensures Plotly uses the 'plotly_mimetype' renderer
* which is compatible with VS Code's notebook rendering when connected to
* Colab runtimes.
*
* The injection is:
* - Idempotent: Only runs once per session
* - Safe: Wrapped in try/except, no-op if Plotly isn't installed
* - Non-invasive: Prepended to user code, doesn't affect execution
*
* @param rawJupyterMessage - The raw JSON string of a Jupyter kernel message
* @returns The potentially modified message string with Plotly config prepended
*/
export function injectPlotlyConfig(rawJupyterMessage: string): string {
if (!rawJupyterMessage) {
return rawJupyterMessage;
}

let parsedMessage: unknown;
try {
parsedMessage = JSON.parse(rawJupyterMessage) as unknown;
} catch {
// Not valid JSON, return as-is
return rawJupyterMessage;
}

if (!isExecuteRequest(parsedMessage)) {
return rawJupyterMessage;
}

const sessionId = parsedMessage.header.session;
if (configuredSessions.has(sessionId)) {
// Already configured this session
return rawJupyterMessage;
}

// Mark session as configured
configuredSessions.add(sessionId);

// Prepend Plotly configuration to user's code
const modifiedMessage = {
...parsedMessage,
content: {
...parsedMessage.content,
code: PLOTLY_CONFIG_CODE + '\n' + parsedMessage.content.code,
},
};

return JSON.stringify(modifiedMessage);
}
Loading