Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 25 additions & 2 deletions core/scripts/generate-python-exchanges.js
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,13 @@ function toClassName(name) {
.join('');
}

function toLegacyClassName(name) {
return name
.split('-')
.map(part => part.charAt(0).toUpperCase() + part.slice(1))
.join('');
}

function parseExchanges(content) {
const startIdx = content.indexOf('function createExchange(');
if (startIdx === -1) throw new Error('createExchange not found in exchange-factory.ts');
Expand Down Expand Up @@ -205,6 +212,9 @@ function generateClass(exchange) {

const appTs = fs.readFileSync(APP_TS_PATH, 'utf8');
const exchanges = parseExchanges(appTs);
const legacyAliases = exchanges
.map(ex => ({ legacyName: toLegacyClassName(ex.name), className: toClassName(ex.name) }))
.filter(alias => alias.legacyName !== alias.className);

const header = [
'# This file is auto-generated by core/scripts/generate-python-exchanges.js',
Expand All @@ -220,8 +230,17 @@ const header = [
].join('\n');

const body = exchanges.map(generateClass).join('\n\n\n');
const aliasBlock = legacyAliases.length
? [
'',
'',
'# Backwards-compatible aliases for exchange classes generated before underscore handling.',
...legacyAliases.map(alias => `${alias.legacyName} = ${alias.className}`),
'',
].join('\n')
: '\n';

fs.writeFileSync(OUTPUT_PATH, header + body + '\n');
fs.writeFileSync(OUTPUT_PATH, header + body + aliasBlock);
console.log(`Generated ${exchanges.length} exchange classes -> ${path.relative(process.cwd(), OUTPUT_PATH)}`);
for (const ex of exchanges) {
console.log(` ${toClassName(ex.name)} (exchange_name="${ex.name}")`);
Expand All @@ -230,7 +249,11 @@ for (const ex of exchanges) {
// ---------------------------------------------------------------------------
// Update __init__.py imports and __all__ to match generated exchanges
// ---------------------------------------------------------------------------
const classNames = exchanges.map(ex => toClassName(ex.name));
const classNames = exchanges.flatMap(ex => {
const className = toClassName(ex.name);
const legacyName = toLegacyClassName(ex.name);
return legacyName === className ? [className] : [className, legacyName];
});
const importList = classNames.join(', ');

let init = fs.readFileSync(INIT_PATH, 'utf8');
Expand Down
7 changes: 6 additions & 1 deletion sdks/python/pmxt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""

from .client import Exchange
from ._exchanges import Polymarket, Limitless, Kalshi, KalshiDemo, Probable, Baozi, Myriad, Opinion, Metaculus, Smarkets, PolymarketUS, Hyperliquid, GeminiTitan, Mock, Router
from ._exchanges import Polymarket, Limitless, Kalshi, KalshiDemo, Probable, Baozi, Myriad, Opinion, Metaculus, Smarkets, PolymarketUS, Polymarket_us, Hyperliquid, GeminiTitan, Mock, Router
from .router import Router
from .server_manager import ServerManager
from .errors import (
Expand Down Expand Up @@ -46,6 +46,7 @@
OrderLevel,
Trade,
UserTrade,
FirehoseEvent,
PaginatedMarketsResult,
PaginatedEventsResult,
Order,
Expand All @@ -60,6 +61,7 @@
EventMatchResult,
PriceComparison,
ArbitrageOpportunity,
SubscribedAddressSnapshot,
MatchRelation,
)

Expand Down Expand Up @@ -144,6 +146,7 @@ def restart_server():
"Metaculus",
"Smarkets",
"PolymarketUS",
"Polymarket_us",
"Hyperliquid",
"GeminiTitan",
"Mock",
Expand Down Expand Up @@ -179,6 +182,7 @@ def restart_server():
"OrderLevel",
"Trade",
"UserTrade",
"FirehoseEvent",
"PaginatedMarketsResult",
"PaginatedEventsResult",
"Order",
Expand All @@ -189,5 +193,6 @@ def restart_server():
"EventMatchResult",
"PriceComparison",
"ArbitrageOpportunity",
"SubscribedAddressSnapshot",
"MatchRelation",
]
3 changes: 3 additions & 0 deletions sdks/python/pmxt/_exchanges.py
Original file line number Diff line number Diff line change
Expand Up @@ -513,3 +513,6 @@ def __init__(
auto_start_server=auto_start_server,
pmxt_api_key=pmxt_api_key,
)

# Backwards-compatible aliases for exchange classes generated before underscore handling.
Polymarket_us = PolymarketUS
69 changes: 69 additions & 0 deletions sdks/python/tests/test_public_exports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
import ast
from pathlib import Path


def test_websocket_return_types_are_public_exports():
init_path = Path(__file__).resolve().parents[1] / "pmxt" / "__init__.py"
tree = ast.parse(init_path.read_text(encoding="utf-8"))

imported_models = set()
public_exports = set()

for node in tree.body:
if isinstance(node, ast.ImportFrom) and node.module == "models":
imported_models.update(alias.name for alias in node.names)
elif (
isinstance(node, ast.Assign)
and len(node.targets) == 1
and isinstance(node.targets[0], ast.Name)
and node.targets[0].id == "__all__"
and isinstance(node.value, ast.List)
):
public_exports.update(
item.value
for item in node.value.elts
if isinstance(item, ast.Constant) and isinstance(item.value, str)
)

expected = {"FirehoseEvent", "SubscribedAddressSnapshot"}
assert expected <= imported_models
assert expected <= public_exports


def test_legacy_polymarket_us_alias_stays_public():
init_path = Path(__file__).resolve().parents[1] / "pmxt" / "__init__.py"
exchanges_path = Path(__file__).resolve().parents[1] / "pmxt" / "_exchanges.py"

init_tree = ast.parse(init_path.read_text(encoding="utf-8"))
exchange_imports = set()
public_exports = set()

for node in init_tree.body:
if isinstance(node, ast.ImportFrom) and node.module == "_exchanges":
exchange_imports.update(alias.name for alias in node.names)
elif (
isinstance(node, ast.Assign)
and len(node.targets) == 1
and isinstance(node.targets[0], ast.Name)
and node.targets[0].id == "__all__"
and isinstance(node.value, ast.List)
):
public_exports.update(
item.value
for item in node.value.elts
if isinstance(item, ast.Constant) and isinstance(item.value, str)
)

exchanges_tree = ast.parse(exchanges_path.read_text(encoding="utf-8"))
aliases = {
node.targets[0].id: node.value.id
for node in exchanges_tree.body
if isinstance(node, ast.Assign)
and len(node.targets) == 1
and isinstance(node.targets[0], ast.Name)
and isinstance(node.value, ast.Name)
}

assert "Polymarket_us" in exchange_imports
assert "Polymarket_us" in public_exports
assert aliases["Polymarket_us"] == "PolymarketUS"
Loading