|
1 | 1 | /* eslint-disable @typescript-eslint/no-explicit-any */ |
2 | 2 | import * as ts from 'typescript'; |
3 | | - |
4 | | -function createProgramAndGetTypeChecker(context: ts.TransformationContext) { |
5 | | - const compilerOptions = context.getCompilerOptions(); |
6 | | - const rootDir = compilerOptions.rootDir || '.'; |
7 | | - // Create a TypeScript program with the transformed source files |
8 | | - const program = ts.createProgram({ |
9 | | - options: compilerOptions, |
10 | | - rootNames: [rootDir], |
11 | | - }); |
12 | | - |
13 | | - // Get the TypeChecker from the program |
14 | | - const typeChecker = program.getTypeChecker(); |
15 | | - |
16 | | - return { program, typeChecker }; |
17 | | -} |
18 | | - |
19 | | -// Function to find the class where the method belongs |
20 | | -function findClassForMethod( |
21 | | - methodNode: ts.MethodDeclaration, |
22 | | -): ts.ClassDeclaration | undefined { |
23 | | - let parent: ts.Node = methodNode.parent; |
24 | | - while (parent) { |
25 | | - if (ts.isClassDeclaration(parent)) { |
26 | | - return parent; // Found the class |
27 | | - } |
28 | | - parent = parent.parent; // Keep looking up the tree |
29 | | - } |
30 | | - return undefined; // No class found (in case it's not part of a class) |
31 | | -} |
| 3 | +import { |
| 4 | + findClassForMethod, |
| 5 | + createProgramAndGetTypeChecker, |
| 6 | + traverseImportFactoryBuilder, |
| 7 | + MapEx, |
| 8 | +} from './helpers'; |
| 9 | +import { registerReferencedExtensions } from './helpers/register-referenced-extensions'; |
32 | 10 |
|
33 | 11 | export function before() { |
| 12 | + const extensions = new MapEx< |
| 13 | + ts.SourceFile, |
| 14 | + MapEx<ts.Type, MapEx<string, ts.Identifier>> |
| 15 | + >(); |
| 16 | + const sources = new MapEx<string, Set<ts.SourceFile>>(); |
| 17 | + const sourceNameMap = new Map<string, ts.SourceFile>(); |
34 | 18 | return (context: ts.TransformationContext) => { |
35 | | - const { typeChecker } = createProgramAndGetTypeChecker(context); |
36 | | - const extensions = new Map< |
37 | | - ts.SourceFile, |
38 | | - Map<ts.Type, Map<string, ts.Identifier>> |
39 | | - >(); |
40 | | - return (rootNode: ts.SourceFile) => { |
41 | | - const registerExtensions: ts.Visitor = (node: ts.Node): ts.Node => { |
| 19 | + const tsRef = createProgramAndGetTypeChecker(context); |
| 20 | + const { traverseImportFactory } = traverseImportFactoryBuilder( |
| 21 | + extensions, |
| 22 | + sources, |
| 23 | + tsRef, |
| 24 | + ); |
| 25 | + |
| 26 | + return function transformExtensionRefs(rootNode: ts.SourceFile) { |
| 27 | + const { getExtensionCall, traverseImport } = traverseImportFactory( |
| 28 | + rootNode, |
| 29 | + sourceNameMap, |
| 30 | + ); |
| 31 | + |
| 32 | + /** |
| 33 | + * Traverse function to register every extension the |
| 34 | + * declared |
| 35 | + * @param node The node to be analyzed |
| 36 | + */ |
| 37 | + function registerExtensions(node: ts.Node): ts.Node { |
42 | 38 | const visitNext = () => |
43 | 39 | ts.visitEachChild(node, registerExtensions, context); |
| 40 | + if (ts.isImportDeclaration(node) || ts.isExportDeclaration(node)) { |
| 41 | + const visited = visitNext(); |
| 42 | + registerReferencedExtensions(sources, node, rootNode, tsRef); |
| 43 | + return visited; |
| 44 | + } |
44 | 45 | const decorators = ts.canHaveDecorators(node) |
45 | 46 | ? ts.getDecorators(node) |
46 | 47 | : undefined; |
47 | 48 | // Handle method declarations with the @ExtensionMethod decorator |
48 | | - if (!ts.isMethodDeclaration(node) || !decorators?.length) { |
49 | | - return visitNext(); |
50 | | - } |
| 49 | + if (!ts.isMethodDeclaration(node)) return visitNext(); |
51 | 50 | // Check for the @ExtensionMethod decorator |
52 | | - const extensionDecorator = decorators.find( |
| 51 | + const extensionDecorator = decorators?.find( |
53 | 52 | (decorator) => decorator.getText() === '@ExtensionMethod', |
54 | 53 | ); |
| 54 | + const first = node.parameters[0]; |
| 55 | + const type = first |
| 56 | + ? tsRef.typeChecker.getTypeAtLocation(first) |
| 57 | + : undefined; |
| 58 | + const cls = findClassForMethod(node); |
55 | 59 |
|
56 | | - if (!extensionDecorator) return visitNext(); |
57 | | - // Ensure the method is static and has 'this' parameter for the extension type |
58 | 60 | if ( |
| 61 | + !extensionDecorator || |
59 | 62 | !node.parameters.length || |
60 | 63 | !node.modifiers?.some( |
61 | 64 | (mod) => mod.kind === ts.SyntaxKind.StaticKeyword, |
62 | | - ) |
| 65 | + ) || |
| 66 | + !type || |
| 67 | + !cls?.name |
63 | 68 | ) { |
64 | 69 | return visitNext(); |
65 | 70 | } |
66 | | - const first = node.parameters[0]; |
67 | | - if (!first) return visitNext(); |
68 | | - const type = typeChecker.getTypeAtLocation(first); |
69 | | - if (!type) return visitNext(); |
70 | | - let sourceMap = extensions.get(rootNode); |
71 | | - if (!sourceMap) { |
72 | | - sourceMap = new Map(); |
73 | | - extensions.set(rootNode, sourceMap); |
74 | | - } |
75 | | - let extensionMethods = sourceMap.get(type); |
76 | | - if (!extensionMethods) { |
77 | | - extensionMethods = new Map(); |
78 | | - sourceMap.set(type, extensionMethods); |
79 | | - } |
80 | | - const cls = findClassForMethod(node); |
81 | | - if (!cls?.name) return visitNext(); |
82 | | - extensionMethods.set(node.name.getText(), cls.name); |
| 71 | + extensions |
| 72 | + .getOrSet(rootNode, () => new MapEx()) |
| 73 | + .getOrSet(type, () => new MapEx()) |
| 74 | + .set(node.name.getText(), cls.name); |
| 75 | + sources.getOrSet(rootNode.fileName, () => new Set()).add(rootNode); |
| 76 | + sourceNameMap.set(rootNode.fileName, rootNode); |
83 | 77 | return ts.visitEachChild(node, registerExtensions, context); |
84 | | - }; |
| 78 | + } |
85 | 79 |
|
86 | | - const transformExtensions: ts.Visitor = (node: ts.Node): ts.Node => { |
| 80 | + /** |
| 81 | + * Traverse function the replace every extension import |
| 82 | + * and every extension method call to static call reference |
| 83 | + * @param node the node to be analyzed |
| 84 | + */ |
| 85 | + function transformExtensions(node: ts.Node): ts.Node { |
87 | 86 | const visitNext = () => |
88 | 87 | ts.visitEachChild(node, transformExtensions, context); |
| 88 | + if (ts.isImportDeclaration(node)) return traverseImport(node); |
89 | 89 | if (!ts.isCallExpression(node)) return visitNext(); |
90 | | - const { expression, arguments: args } = node; |
91 | | - if (!ts.isPropertyAccessExpression(expression)) return visitNext(); |
92 | | - const targetInstance = expression.expression; |
93 | | - const methodName = expression.name.getText(); |
94 | | - if (!targetInstance || !methodName) return visitNext(); |
95 | | - const extensionList = extensions.get(rootNode); |
96 | | - if (!extensionList) return visitNext(); |
97 | | - const type = typeChecker.getTypeAtLocation(targetInstance); |
98 | | - if (!type) return visitNext(); |
99 | | - let extension = extensionList.get(type)?.get(methodName); |
100 | | - if (!extension) { |
101 | | - for (const [key, value] of extensionList.entries()) { |
102 | | - if (typeChecker.isTypeAssignableTo(type, key)) { |
103 | | - extension = value.get(methodName); |
104 | | - if (extension) break; |
105 | | - } |
106 | | - } |
107 | | - } |
108 | | - if (!extension) return visitNext(); |
109 | | - |
110 | | - // Create the transformed call: MyExtensionClass.myExtensionMethod(myInstance) |
111 | | - return ts.factory.createCallExpression( |
112 | | - ts.factory.createPropertyAccessExpression( |
113 | | - extension, |
114 | | - ts.factory.createIdentifier(methodName), |
115 | | - ), |
116 | | - undefined, |
117 | | - [targetInstance, ...args], |
118 | | - ); |
119 | | - }; |
| 90 | + return getExtensionCall(node) ?? visitNext(); |
| 91 | + } |
120 | 92 |
|
121 | 93 | ts.visitNode(rootNode, registerExtensions); |
122 | 94 |
|
|
0 commit comments