From 189c49fae2f51599240d10a670d8036d6316c5eb Mon Sep 17 00:00:00 2001 From: Jyoti Prakash Date: Mon, 10 Nov 2025 22:22:11 +0530 Subject: [PATCH 1/3] module --- src/main/kotlin/slang/hlir/Ast.kt | 36 ++++----- src/main/kotlin/slang/hlir/AstBuilder.kt | 10 ++- .../kotlin/slang/hlir/ControlFlowGraph.kt | 73 ++++++++++++++++--- src/main/kotlin/slang/runtime/Interpreter.kt | 26 ++++++- src/main/kotlin/slang/ui/AstVisualizer.kt | 8 ++ 5 files changed, 117 insertions(+), 36 deletions(-) diff --git a/src/main/kotlin/slang/hlir/Ast.kt b/src/main/kotlin/slang/hlir/Ast.kt index 3353dd4..c3bf012 100644 --- a/src/main/kotlin/slang/hlir/Ast.kt +++ b/src/main/kotlin/slang/hlir/Ast.kt @@ -1,9 +1,5 @@ package slang.hlir -import SlangLexer -import SlangParser -import org.antlr.v4.runtime.ANTLRInputStream -import org.antlr.v4.runtime.CommonTokenStream import slang.common.CodeInfo import slang.common.CodeInfo.Companion.generic @@ -66,9 +62,11 @@ enum class Operator { } data class ProgramUnit( - val stmt: List, + val stmt: List, ) : SlastNode() +data class SlangModule(val functions : List, val inlinedFuncs : List) : SlastNode() + sealed class Stmt : SlastNode() { data class LetStmt( val name: String, @@ -229,40 +227,38 @@ fun SlastNode.prettyPrint(tabStop: Int = 0): String { is Stmt.Function -> "$indent fun $name(${params.joinToString(", ")}) {\n" + body.prettyPrint(tabStop + 1) + "\n$indent}" is Expr.NoneValue -> "None" is Expr.InlinedFunction -> "$indent inline_fun (${params.joinToString(", ")}) => ${body.prettyPrint()}" - is Stmt.IfStmt -> - "$indent if (${condition.prettyPrint()}) {\n" + thenBody.prettyPrint(tabStop + 1) + "\n$indent} else {\n" + - elseBody.prettyPrint( - tabStop + 1, - ) + "\n$indent}" + is Stmt.IfStmt -> "$indent if (${condition.prettyPrint()}) {\n" + thenBody.prettyPrint(tabStop + 1) + "\n$indent} else {\n" + elseBody.prettyPrint( + tabStop + 1, + ) + "\n$indent}" is Stmt.LetStmt -> "$indent let $name = ${expr.prettyPrint()};" is Stmt.PrintStmt -> "$indent print(${args.joinToString(", ") { it.prettyPrint() }});" is Stmt.ReturnStmt -> "$indent return ${expr.prettyPrint()};" is Stmt.WhileStmt -> "$indent while (${condition.prettyPrint()}) {\n" + body.prettyPrint(tabStop + 1) + "\n$indent}" - is Expr.Record -> - "$indent{\n" + expression.joinToString("\n") { "$indent ${it.first} : ${it.second.prettyPrint()}" } + - "\n$indent}" + is Expr.Record -> "$indent{\n" + expression.joinToString("\n") { "$indent ${it.first} : ${it.second.prettyPrint()}" } + "\n$indent}" is Expr.StringLiteral -> "\"$value\"" is Expr.DerefExpr -> "deref(${expr.prettyPrint()})" is Expr.RefExpr -> "ref(${expr.prettyPrint()})" is Stmt.DerefStmt -> "$indent deref(${lhs.prettyPrint()}) = ${rhs.prettyPrint()};" is Expr.FieldAccess -> "${lhs.prettyPrint()}.${rhs.prettyPrint()}" - is Stmt.StructStmt -> - "$indent struct $id {\n" + functions.joinToString("\n") { it.prettyPrint(tabStop + 1) } + "\n$indent}" + - fields.entries.joinToString( - "\n", - ) { "$indent ${it.key} : ${it.value.prettyPrint()}" } + is Stmt.StructStmt -> "$indent struct $id {\n" + functions.joinToString("\n") { it.prettyPrint(tabStop + 1) } + "\n$indent}" + fields.entries.joinToString( + "\n", + ) { "$indent ${it.key} : ${it.value.prettyPrint()}" } is Expr.ArrayAccess -> "${array.prettyPrint()}[${index.prettyPrint()}]" is Expr.ArrayInit -> "$indent [${elements.joinToString(", ") { it.prettyPrint() }}]" is Stmt.Break -> "$indent break;" is Stmt.Continue -> "$indent continue;" + is SlangModule -> { + val funcsStr = functions.joinToString("\n") { it.prettyPrint(tabStop) } + val inlinedStr = inlinedFuncs.joinToString("\n") { it.prettyPrint(tabStop) } + listOf(funcsStr, inlinedStr).filter { it.isNotBlank() }.joinToString("\n\n") + } } } fun main() { - val inputCode = - """ + val inputCode = """ fun power(base, exp) { if (base > exp) { base = base + 1; diff --git a/src/main/kotlin/slang/hlir/AstBuilder.kt b/src/main/kotlin/slang/hlir/AstBuilder.kt index 7967c14..eb368eb 100644 --- a/src/main/kotlin/slang/hlir/AstBuilder.kt +++ b/src/main/kotlin/slang/hlir/AstBuilder.kt @@ -210,7 +210,15 @@ class SlastBuilder( ProgramUnit(emptyList()) } else { val stmts = ctx.stmt().map { visit(it) as Stmt } - ProgramUnit(stmts) + val moduleMain = Stmt.Function("__module__main__", emptyList(), Stmt.BlockStmt + (stmts.filterNot { + it is Stmt.Function || (it is Stmt.ExprStmt && it.expr is Expr.InlinedFunction) + })) + val topLevelInlineFuncs = stmts.filterIsInstance().map { it.expr } + .filterIsInstance() + val topLevelfuncs = stmts.filterIsInstance() + val slangModule = SlangModule(listOf(moduleMain) + topLevelfuncs, topLevelInlineFuncs ) + ProgramUnit(listOf(slangModule)) } expr.codeInfo = createSourceCodeInfo(ctx) return expr diff --git a/src/main/kotlin/slang/hlir/ControlFlowGraph.kt b/src/main/kotlin/slang/hlir/ControlFlowGraph.kt index 9569f72..44da396 100644 --- a/src/main/kotlin/slang/hlir/ControlFlowGraph.kt +++ b/src/main/kotlin/slang/hlir/ControlFlowGraph.kt @@ -98,17 +98,34 @@ class CFGBuilder { * Build CFG for a program */ fun buildForProgram(program: ProgramUnit): ControlFlowGraph { + // For backward compatibility with tests, return a single CFG representing the + // module-level "__module__main__" function if present. Otherwise, build a + // synthetic entry/exit with any top-level statements. blockIdCounter = 0 allBlocks.clear() - val entry = newBlock() - val exit = newBlock() - - val bodyBlock = buildForStmtList(program.stmt, exit) - addEdge(entry, bodyBlock.entry) - addEdge(bodyBlock.exit, exit) + if (program.stmt.isEmpty()) { + val entry = newBlock() + val exit = newBlock() + return ControlFlowGraph(entry, exit, allBlocks) + } - return ControlFlowGraph(entry, exit, allBlocks) + // Use the first module for program-level CFG + val module = program.stmt[0] + + // Try to find the synthetic module main function created by the IR builder + val moduleMain = module.functions.find { it.name == "__module__main__" } + return if (moduleMain != null) { + buildForFunction(moduleMain) + } else if (module.functions.isNotEmpty()) { + // Fallback: build CFG for the first top-level function + buildForFunction(module.functions[0]) + } else { + // No functions: create empty entry/exit + val entry = newBlock() + val exit = newBlock() + ControlFlowGraph(entry, exit, allBlocks) + } } private data class CFGSegment( @@ -130,13 +147,25 @@ class CFGBuilder { var currentSegment = buildForStmt(stmts[0], exitBlock) val entry = currentSegment.entry + val accumulatedBreaks = mutableListOf() + val accumulatedContinues = mutableListOf() + accumulatedBreaks.addAll(currentSegment.breakTargets) + accumulatedContinues.addAll(currentSegment.continueTargets) + for (i in 1 until stmts.size) { val nextSegment = buildForStmt(stmts[i], exitBlock) + // Normal flow: connect the previous segment's exit to the next segment's entry addEdge(currentSegment.exit, nextSegment.entry) - currentSegment = CFGSegment(entry, nextSegment.exit) + + // Accumulate break/continue targets from subsequent segments; they should not be + // connected into the normal fall-through chain here (they are handled by loops) + accumulatedBreaks.addAll(nextSegment.breakTargets) + accumulatedContinues.addAll(nextSegment.continueTargets) + + currentSegment = CFGSegment(entry, nextSegment.exit, accumulatedBreaks.toList(), accumulatedContinues.toList()) } - return CFGSegment(entry, currentSegment.exit) + return CFGSegment(entry, currentSegment.exit, accumulatedBreaks.toList(), accumulatedContinues.toList()) } private fun buildForStmt( @@ -175,32 +204,52 @@ class CFGBuilder { addEdge(thenSegment.exit, mergeBlock) addEdge(elseSegment.exit, mergeBlock) - CFGSegment(condBlock, mergeBlock) + // combine break/continue targets from both branches and propagate upward + val combinedBreaks = thenSegment.breakTargets + elseSegment.breakTargets + val combinedContinues = thenSegment.continueTargets + elseSegment.continueTargets + + CFGSegment(condBlock, mergeBlock, combinedBreaks, combinedContinues) } is Stmt.WhileStmt -> { val condBlock = newBlock(listOf(stmt)) val mergeBlock = newBlock() + + // Build the loop body with the knowledge that its breaks should target mergeBlock val bodySegment = buildForStmt(stmt.body, exitBlock) + // Normal loop edges: cond -> body, body -> cond, cond -> merge (loop exit) addEdge(condBlock, bodySegment.entry) addEdge(bodySegment.exit, condBlock) addEdge(condBlock, mergeBlock) + // Resolve break targets inside the loop: they should jump to mergeBlock + for (bt in bodySegment.breakTargets) { + addEdge(bt, mergeBlock) + } + + // Resolve continue targets inside the loop: they should jump to the condition block + for (ct in bodySegment.continueTargets) { + addEdge(ct, condBlock) + } + + // Consumed break/continue targets should not propagate beyond this loop CFGSegment(condBlock, mergeBlock) } is Stmt.Break -> { val block = newBlock(listOf(stmt)) // Break statements need special handling - they jump to the loop exit - // For now, we create a dead-end block + // Represent a break by returning the block as a break target. It will be + // wired up by the nearest enclosing loop to jump to the loop exit. CFGSegment(block, block, breakTargets = listOf(block)) } is Stmt.Continue -> { val block = newBlock(listOf(stmt)) // Continue statements need special handling - they jump to the loop condition - // For now, we create a dead-end block + // Represent a continue by returning the block as a continue target. It will be + // wired up by the nearest enclosing loop to jump back to the loop condition. CFGSegment(block, block, continueTargets = listOf(block)) } diff --git a/src/main/kotlin/slang/runtime/Interpreter.kt b/src/main/kotlin/slang/runtime/Interpreter.kt index ff8d5dd..ec024c2 100644 --- a/src/main/kotlin/slang/runtime/Interpreter.kt +++ b/src/main/kotlin/slang/runtime/Interpreter.kt @@ -15,9 +15,29 @@ class Interpreter { program: ProgramUnit, state: ConcreteState = ConcreteState(), ): ConcreteState = - program.stmt.fold(state) { currentState, stmt -> - val (newState, _) = executeStmt(stmt, currentState) - newState + // ProgramUnit now contains a list of SlangModule entries. Each module holds top-level functions + // and a synthetic module main that contains top-level statements. We need to: + // 1) register all top-level functions into the environment + // 2) execute the synthetic module main body so top-level stmts run + program.stmt.fold(state) { currentState, module -> + var s = currentState + + // Register all top-level functions except the synthetic module main + val moduleMain = module.functions.find { it.name == "__module__main__" } + for (f in module.functions) { + if (f !== moduleMain) { + val (newState, _) = executeStmt(f, s) + s = newState + } + } + + // Execute the synthetic module main body (if present) to run top-level statements + if (moduleMain != null) { + val (afterMainState, _) = executeStmt(moduleMain.body, s) + s = afterMainState + } + + s } private fun executeStmt( diff --git a/src/main/kotlin/slang/ui/AstVisualizer.kt b/src/main/kotlin/slang/ui/AstVisualizer.kt index 561c786..6e98bcd 100644 --- a/src/main/kotlin/slang/ui/AstVisualizer.kt +++ b/src/main/kotlin/slang/ui/AstVisualizer.kt @@ -8,6 +8,7 @@ import org.fife.ui.rsyntaxtextarea.TokenMakerFactory import org.fife.ui.rtextarea.RTextScrollPane import slang.hlir.Expr import slang.hlir.ProgramUnit +import slang.hlir.SlangModule import slang.hlir.SlastNode import slang.hlir.Stmt import slang.hlir.prettyPrint @@ -184,6 +185,13 @@ fun SlastNode.toTreeNode(): DefaultMutableTreeNode = ) } } + is SlangModule -> + DefaultMutableTreeNode("Module").apply { + add(DefaultMutableTreeNode("Functions").apply { functions.forEach { add(it.toTreeNode()) } }) + if (inlinedFuncs.isNotEmpty()) { + add(DefaultMutableTreeNode("InlinedFunctions").apply { inlinedFuncs.forEach { add(it.toTreeNode()) } }) + } + } } fun expandAllNodes(tree: JTree) { From 8371ac1d7721ed6b873b0a587bbff1320e00af6f Mon Sep 17 00:00:00 2001 From: Jyoti Prakash Date: Mon, 10 Nov 2025 22:45:17 +0530 Subject: [PATCH 2/3] test cases and snapshots --- build.gradle.kts | 17 +++- scripts/approve_snapshots.sh | 95 +++++-------------- src/test/kotlin/DataflowAnalysisTest.kt | 5 +- src/test/kotlin/IRBuilderTests.kt | 6 +- .../IRBuilderTests.testCase1.approved.txt | 20 +++- .../IRBuilderTests.testCase2.approved.txt | 13 ++- .../IRBuilderTests.testCase3.approved.txt | 16 +++- .../IRBuilderTests.testCase4.approved.txt | 16 +++- .../IRBuilderTests.testCase6.approved.txt | 11 ++- .../IRBuilderTests.testCase7.approved.txt | 18 +++- .../IRBuilderTests.testCase8.approved.txt | 1 - 11 files changed, 134 insertions(+), 84 deletions(-) diff --git a/build.gradle.kts b/build.gradle.kts index 43b5147..5ad9aaf 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -1,3 +1,6 @@ +import java.nio.file.Files +import java.nio.file.Paths + plugins { kotlin("jvm") version "2.1.0" antlr @@ -69,18 +72,22 @@ application { // Dry-run by default; pass -PapproveSnapshotsCommit=true to actually git-mv and commit. tasks.register("approveSnapshots") { group = "verification" - description = "Run scripts/approve_snapshots.sh (dry-run). Use -PapproveSnapshotsCommit=true to commit approvals." + description = "Run scripts/approve_snapshots.sh (dry-run). Use -PtestPath=/path/to/test to commit approvals." - val commit = (project.findProperty("approveSnapshotsCommit") as String?).toBoolean() + val testPath = project.findProperty("testPath") as String? val scriptFile = file("scripts/approve_snapshots.sh") if (!scriptFile.exists()) { throw GradleException("Snapshot approval script not found: ${scriptFile.absolutePath}") } - // Use bash for portability and to support script features. - commandLine = listOf("bash", scriptFile.absolutePath) + if (commit) listOf("--commit") else emptyList() + if (testPath == null) { + throw GradleException("Test path not found}") + } - // Stream output to console so user sees what will happen. + if (Files.notExists(Paths.get(testPath))) { + throw GradleException("Directory does not exist: $testPath") + } + commandLine = listOf("bash", scriptFile.absolutePath, testPath) isIgnoreExitValue = false } diff --git a/scripts/approve_snapshots.sh b/scripts/approve_snapshots.sh index 93fd6b9..2c0cbe8 100755 --- a/scripts/approve_snapshots.sh +++ b/scripts/approve_snapshots.sh @@ -1,83 +1,40 @@ #!/usr/bin/env bash set -euo pipefail -# approve_snapshots.sh -# Find files named *.received.* (ApprovalTests pattern) and optionally rename them to *.approved.* -# Usage: -# ./scripts/approve_snapshots.sh # dry run - prints changes that would happen -# ./scripts/approve_snapshots.sh --commit # perform git mv and commit the changes -# ./scripts/approve_snapshots.sh --help # show help +# Usage: `scripts/approve_snapshots.sh` +# Rename all files in the given directory from '*.received.txt' to '*.approved.txt'. +# Non-recursive. Skips targets that already exist. -SHOW_HELP=0 -DO_COMMIT=0 - -for arg in "$@"; do - case "$arg" in - --commit) DO_COMMIT=1 ;; - --help|-h) SHOW_HELP=1 ;; - *) echo "Unknown arg: $arg" ; exit 1 ;; - esac -done - -if [ "$SHOW_HELP" -eq 1 ]; then - sed -n '1,200p' "$0" - exit 0 +if [ "$#" -ne 1 ]; then + echo "Usage: $0 " >&2 + exit 2 fi -# Find received files: patterns like *.received.txt or *.received.md or *.received.json -RECEIVED_FILES=() -while IFS= read -r -d '' file; do - RECEIVED_FILES+=("$file") -done < <(find . -type f -name "*.received.*" -print0) +dir="$1" -if [ ${#RECEIVED_FILES[@]} -eq 0 ]; then - echo "No received snapshot files found. Nothing to do." - exit 0 +if [ ! -d "$dir" ]; then + echo "Error: '$dir' is not a directory" >&2 + exit 3 fi -echo "Found ${#RECEIVED_FILES[@]} received file(s):" -for f in "${RECEIVED_FILES[@]}"; do - echo " $f" -done - -actions=() -for f in "${RECEIVED_FILES[@]}"; do - # compute approved filename: replace .received. with .approved. - approved="${f//.received./.approved.}" - actions+=("$f -> $approved") -done +found_any=false -echo -if [ "$DO_COMMIT" -eq 0 ]; then - echo "DRY RUN: the following rename operations would be performed:" - for a in "${actions[@]}"; do - echo " $a" - done - echo - echo "Run with --commit to actually perform git mv and commit the changes." - exit 0 -fi +for f in "$dir"/*.received.txt; do + if [ ! -e "$f" ]; then + # No matching files (shell left the pattern unexpanded) + if [ "$found_any" = false ]; then + echo "No '*.received.txt' files found in '$dir'." >&2 + exit 0 + fi + break + fi -# If we get here, perform commit -# Ensure working tree is clean -if [ -n "$(git status --porcelain)" ]; then - echo "Working tree is not clean. Please commit or stash your changes before running with --commit." >&2 - git status --porcelain - exit 1 -fi + found_any=true + [ -f "$f" ] || continue -# Perform git mv operations -for f in "${RECEIVED_FILES[@]}"; do - approved="${f//.received./.approved.}" - echo "git mv '$f' '$approved'" - git mv -- "$f" "$approved" + target="${f%.received.txt}.approved.txt" + mv "$f" "$target" + echo "Renamed: '$f' -> '$target'" done -# Commit -msg="chore(test): approve snapshots ($(date -u +%Y-%m-%dT%H:%M:%SZ))" -git add -A -git commit -m "$msg" - -echo "Committed snapshot approvals." - -exit 0 +exit 0 \ No newline at end of file diff --git a/src/test/kotlin/DataflowAnalysisTest.kt b/src/test/kotlin/DataflowAnalysisTest.kt index ef95026..144d512 100644 --- a/src/test/kotlin/DataflowAnalysisTest.kt +++ b/src/test/kotlin/DataflowAnalysisTest.kt @@ -276,10 +276,11 @@ class DataflowAnalysisTest { assertTrue(result is Result.Ok) val programUnit = (result as Result.Ok).value - val function = programUnit.stmt.filterIsInstance().firstOrNull() + // pick the first user-defined top-level function (skip synthetic module main) + val function = programUnit.stmt.flatMap { it.functions }.firstOrNull { it.name != "__module__main__" } assertNotNull(function) - val cfg = function.buildCFG() + val cfg = function!!.buildCFG() val lvAnalysis = LiveVariablesAnalysis() val lvResult = lvAnalysis.analyze(cfg) diff --git a/src/test/kotlin/IRBuilderTests.kt b/src/test/kotlin/IRBuilderTests.kt index 8ea1bcb..5184702 100644 --- a/src/test/kotlin/IRBuilderTests.kt +++ b/src/test/kotlin/IRBuilderTests.kt @@ -2,6 +2,7 @@ import org.approvaltests.Approvals import slang.common.invoke import slang.common.then import slang.hlir.ParseTree2HlirTrasnformer +import slang.hlir.prettyPrint import slang.parser.File2ParseTreeTransformer import java.io.File import java.net.URL @@ -12,7 +13,10 @@ class IRBuilderTests { fun buildAst(testCase: URL): String { val hlir = (File2ParseTreeTransformer() then ParseTree2HlirTrasnformer()).invoke(File(testCase.toURI())) - return hlir.toString() + return hlir.fold( + {it.prettyPrint(0)}, + {""} + ) } @Test diff --git a/src/test/kotlin/IRBuilderTests.testCase1.approved.txt b/src/test/kotlin/IRBuilderTests.testCase1.approved.txt index d60e624..6623004 100644 --- a/src/test/kotlin/IRBuilderTests.testCase1.approved.txt +++ b/src/test/kotlin/IRBuilderTests.testCase1.approved.txt @@ -1 +1,19 @@ -Ok(value=ProgramUnit(stmt=[Function(name=power, params=[base, exp], body=BlockStmt(stmts=[LetStmt(name=result, expr=NumberLiteral(value=1.0)), WhileStmt(condition=BinaryExpr(left=VarExpr(name=exp), op=>, right=NumberLiteral(value=0.0)), body=BlockStmt(stmts=[AssignStmt(lhs=VarExpr(name=result), expr=BinaryExpr(left=VarExpr(name=result), op=*, right=VarExpr(name=base))), AssignStmt(lhs=VarExpr(name=exp), expr=BinaryExpr(left=VarExpr(name=exp), op=-, right=NumberLiteral(value=1.0)))])), ReturnStmt(expr=VarExpr(name=result))])), LetStmt(name=base, expr=ReadInputExpr), LetStmt(name=exp, expr=ReadInputExpr), PrintStmt(args=[NamedFunctionCall(name=power, arguments=[VarExpr(name=base), VarExpr(name=exp)])])])) \ No newline at end of file + fun __module__main__() { + { + let base = readInput(); + let exp = readInput(); + print(power(base, exp)); + } +} + fun power(base, exp) { + { + let result = 1.0; + while (exp > 0.0) { + { + result = result * base; + exp = exp - 1.0; + } + } + return result; + } +} \ No newline at end of file diff --git a/src/test/kotlin/IRBuilderTests.testCase2.approved.txt b/src/test/kotlin/IRBuilderTests.testCase2.approved.txt index dd44653..7e70fb8 100644 --- a/src/test/kotlin/IRBuilderTests.testCase2.approved.txt +++ b/src/test/kotlin/IRBuilderTests.testCase2.approved.txt @@ -1 +1,12 @@ -Ok(value=ProgramUnit(stmt=[Function(name=factorial, params=[n], body=BlockStmt(stmts=[ReturnStmt(expr=IfExpr(condition=BinaryExpr(left=VarExpr(name=n), op===, right=NumberLiteral(value=0.0)), thenExpr=NumberLiteral(value=1.0), elseExpr=BinaryExpr(left=VarExpr(name=n), op=*, right=NamedFunctionCall(name=factorial, arguments=[BinaryExpr(left=VarExpr(name=n), op=-, right=NumberLiteral(value=1.0))]))))])), LetStmt(name=num, expr=ReadInputExpr), LetStmt(name=fact, expr=NamedFunctionCall(name=factorial, arguments=[VarExpr(name=num)])), PrintStmt(args=[VarExpr(name=fact)])])) \ No newline at end of file + fun __module__main__() { + { + let num = readInput(); + let fact = factorial(num); + print(fact); + } +} + fun factorial(n) { + { + return if (n == 0.0) then 1.0 else n * factorial(n - 1.0); + } +} \ No newline at end of file diff --git a/src/test/kotlin/IRBuilderTests.testCase3.approved.txt b/src/test/kotlin/IRBuilderTests.testCase3.approved.txt index d3caeb4..346feb3 100644 --- a/src/test/kotlin/IRBuilderTests.testCase3.approved.txt +++ b/src/test/kotlin/IRBuilderTests.testCase3.approved.txt @@ -1 +1,15 @@ -Ok(value=ProgramUnit(stmt=[Function(name=mod, params=[a, b], body=BlockStmt(stmts=[ReturnStmt(expr=IfExpr(condition=BinaryExpr(left=VarExpr(name=a), op=<, right=VarExpr(name=b)), thenExpr=VarExpr(name=a), elseExpr=NamedFunctionCall(name=mod, arguments=[BinaryExpr(left=VarExpr(name=a), op=-, right=VarExpr(name=b)), VarExpr(name=b)])))])), Function(name=gcd, params=[a, b], body=BlockStmt(stmts=[ReturnStmt(expr=IfExpr(condition=BinaryExpr(left=VarExpr(name=b), op===, right=NumberLiteral(value=0.0)), thenExpr=VarExpr(name=a), elseExpr=NamedFunctionCall(name=gcd, arguments=[VarExpr(name=b), NamedFunctionCall(name=mod, arguments=[VarExpr(name=a), VarExpr(name=b)])])))]))])) \ No newline at end of file + fun __module__main__() { + { + + } +} + fun mod(a, b) { + { + return if (a < b) then a else mod(a - b, b); + } +} + fun gcd(a, b) { + { + return if (b == 0.0) then a else gcd(b, mod(a, b)); + } +} \ No newline at end of file diff --git a/src/test/kotlin/IRBuilderTests.testCase4.approved.txt b/src/test/kotlin/IRBuilderTests.testCase4.approved.txt index 7ca77da..668d7e8 100644 --- a/src/test/kotlin/IRBuilderTests.testCase4.approved.txt +++ b/src/test/kotlin/IRBuilderTests.testCase4.approved.txt @@ -1 +1,15 @@ -Ok(value=ProgramUnit(stmt=[LetStmt(name=a, expr=ReadInputExpr), LetStmt(name=b, expr=ReadInputExpr), IfStmt(condition=BinaryExpr(left=VarExpr(name=a), op=>, right=VarExpr(name=b)), thenBody=BlockStmt(stmts=[PrintStmt(args=[VarExpr(name=a)])]), elseBody=BlockStmt(stmts=[PrintStmt(args=[VarExpr(name=b)])]))])) \ No newline at end of file + fun __module__main__() { + { + let a = readInput(); + let b = readInput(); + if (a > b) { + { + print(a); + } + } else { + { + print(b); + } + } + } +} \ No newline at end of file diff --git a/src/test/kotlin/IRBuilderTests.testCase6.approved.txt b/src/test/kotlin/IRBuilderTests.testCase6.approved.txt index 471efde..9b14673 100644 --- a/src/test/kotlin/IRBuilderTests.testCase6.approved.txt +++ b/src/test/kotlin/IRBuilderTests.testCase6.approved.txt @@ -1 +1,10 @@ -Ok(value=ProgramUnit(stmt=[LetStmt(name=x, expr=NumberLiteral(value=10.0)), LetStmt(name=y, expr=NumberLiteral(value=5.0)), LetStmt(name=sum, expr=BinaryExpr(left=VarExpr(name=x), op=+, right=VarExpr(name=y))), LetStmt(name=product, expr=BinaryExpr(left=VarExpr(name=x), op=*, right=VarExpr(name=y))), PrintStmt(args=[VarExpr(name=sum)]), PrintStmt(args=[VarExpr(name=product)])])) \ No newline at end of file + fun __module__main__() { + { + let x = 10.0; + let y = 5.0; + let sum = x + y; + let product = x * y; + print(sum); + print(product); + } +} \ No newline at end of file diff --git a/src/test/kotlin/IRBuilderTests.testCase7.approved.txt b/src/test/kotlin/IRBuilderTests.testCase7.approved.txt index af169e4..038d090 100644 --- a/src/test/kotlin/IRBuilderTests.testCase7.approved.txt +++ b/src/test/kotlin/IRBuilderTests.testCase7.approved.txt @@ -1 +1,17 @@ -Ok(value=ProgramUnit(stmt=[LetStmt(name=a, expr=NumberLiteral(value=1.0)), BlockStmt(stmts=[BlockStmt(stmts=[PrintStmt(args=[VarExpr(name=a)]), AssignStmt(lhs=VarExpr(name=a), expr=BinaryExpr(left=VarExpr(name=a), op=+, right=NumberLiteral(value=1.0)))]), WhileStmt(condition=BinaryExpr(left=VarExpr(name=a), op=<, right=NumberLiteral(value=5.0)), body=BlockStmt(stmts=[PrintStmt(args=[VarExpr(name=a)]), AssignStmt(lhs=VarExpr(name=a), expr=BinaryExpr(left=VarExpr(name=a), op=+, right=NumberLiteral(value=1.0)))]))])])) \ No newline at end of file + fun __module__main__() { + { + let a = 1.0; + { + { + print(a); + a = a + 1.0; + } + while (a < 5.0) { + { + print(a); + a = a + 1.0; + } + } + } + } +} \ No newline at end of file diff --git a/src/test/kotlin/IRBuilderTests.testCase8.approved.txt b/src/test/kotlin/IRBuilderTests.testCase8.approved.txt index 7384aae..e69de29 100644 --- a/src/test/kotlin/IRBuilderTests.testCase8.approved.txt +++ b/src/test/kotlin/IRBuilderTests.testCase8.approved.txt @@ -1 +0,0 @@ -Err(error=[[35:4 -- 35:13] - Variable z already declared in this scope, [36:4 -- 36:13] - Variable n already declared in this scope]) \ No newline at end of file From 5493d2022b8774eb1ce12500938edefbd0987be3 Mon Sep 17 00:00:00 2001 From: Jyoti Prakash Date: Mon, 10 Nov 2025 22:58:43 +0530 Subject: [PATCH 3/3] fixed test cases and rebased --- src/main/kotlin/slang/hlir/Ast.kt | 28 ++++++++++++++++-------- src/main/kotlin/slang/hlir/AstBuilder.kt | 23 +++++++++++++------ src/test/kotlin/ControlFlowGraphTest.kt | 4 ++-- src/test/kotlin/IRBuilderTests.kt | 4 ++-- 4 files changed, 39 insertions(+), 20 deletions(-) diff --git a/src/main/kotlin/slang/hlir/Ast.kt b/src/main/kotlin/slang/hlir/Ast.kt index c3bf012..4741ecc 100644 --- a/src/main/kotlin/slang/hlir/Ast.kt +++ b/src/main/kotlin/slang/hlir/Ast.kt @@ -65,7 +65,10 @@ data class ProgramUnit( val stmt: List, ) : SlastNode() -data class SlangModule(val functions : List, val inlinedFuncs : List) : SlastNode() +data class SlangModule( + val functions: List, + val inlinedFuncs: List, +) : SlastNode() sealed class Stmt : SlastNode() { data class LetStmt( @@ -227,23 +230,29 @@ fun SlastNode.prettyPrint(tabStop: Int = 0): String { is Stmt.Function -> "$indent fun $name(${params.joinToString(", ")}) {\n" + body.prettyPrint(tabStop + 1) + "\n$indent}" is Expr.NoneValue -> "None" is Expr.InlinedFunction -> "$indent inline_fun (${params.joinToString(", ")}) => ${body.prettyPrint()}" - is Stmt.IfStmt -> "$indent if (${condition.prettyPrint()}) {\n" + thenBody.prettyPrint(tabStop + 1) + "\n$indent} else {\n" + elseBody.prettyPrint( - tabStop + 1, - ) + "\n$indent}" + is Stmt.IfStmt -> + "$indent if (${condition.prettyPrint()}) {\n" + thenBody.prettyPrint(tabStop + 1) + "\n$indent} else {\n" + + elseBody.prettyPrint( + tabStop + 1, + ) + "\n$indent}" is Stmt.LetStmt -> "$indent let $name = ${expr.prettyPrint()};" is Stmt.PrintStmt -> "$indent print(${args.joinToString(", ") { it.prettyPrint() }});" is Stmt.ReturnStmt -> "$indent return ${expr.prettyPrint()};" is Stmt.WhileStmt -> "$indent while (${condition.prettyPrint()}) {\n" + body.prettyPrint(tabStop + 1) + "\n$indent}" - is Expr.Record -> "$indent{\n" + expression.joinToString("\n") { "$indent ${it.first} : ${it.second.prettyPrint()}" } + "\n$indent}" + is Expr.Record -> + "$indent{\n" + expression.joinToString("\n") { "$indent ${it.first} : ${it.second.prettyPrint()}" } + + "\n$indent}" is Expr.StringLiteral -> "\"$value\"" is Expr.DerefExpr -> "deref(${expr.prettyPrint()})" is Expr.RefExpr -> "ref(${expr.prettyPrint()})" is Stmt.DerefStmt -> "$indent deref(${lhs.prettyPrint()}) = ${rhs.prettyPrint()};" is Expr.FieldAccess -> "${lhs.prettyPrint()}.${rhs.prettyPrint()}" - is Stmt.StructStmt -> "$indent struct $id {\n" + functions.joinToString("\n") { it.prettyPrint(tabStop + 1) } + "\n$indent}" + fields.entries.joinToString( - "\n", - ) { "$indent ${it.key} : ${it.value.prettyPrint()}" } + is Stmt.StructStmt -> + "$indent struct $id {\n" + functions.joinToString("\n") { it.prettyPrint(tabStop + 1) } + "\n$indent}" + + fields.entries.joinToString( + "\n", + ) { "$indent ${it.key} : ${it.value.prettyPrint()}" } is Expr.ArrayAccess -> "${array.prettyPrint()}[${index.prettyPrint()}]" is Expr.ArrayInit -> "$indent [${elements.joinToString(", ") { it.prettyPrint() }}]" @@ -258,7 +267,8 @@ fun SlastNode.prettyPrint(tabStop: Int = 0): String { } fun main() { - val inputCode = """ + val inputCode = + """ fun power(base, exp) { if (base > exp) { base = base + 1; diff --git a/src/main/kotlin/slang/hlir/AstBuilder.kt b/src/main/kotlin/slang/hlir/AstBuilder.kt index eb368eb..ed82e42 100644 --- a/src/main/kotlin/slang/hlir/AstBuilder.kt +++ b/src/main/kotlin/slang/hlir/AstBuilder.kt @@ -210,14 +210,23 @@ class SlastBuilder( ProgramUnit(emptyList()) } else { val stmts = ctx.stmt().map { visit(it) as Stmt } - val moduleMain = Stmt.Function("__module__main__", emptyList(), Stmt.BlockStmt - (stmts.filterNot { - it is Stmt.Function || (it is Stmt.ExprStmt && it.expr is Expr.InlinedFunction) - })) - val topLevelInlineFuncs = stmts.filterIsInstance().map { it.expr } - .filterIsInstance() + val moduleMain = + Stmt.Function( + "__module__main__", + emptyList(), + Stmt.BlockStmt( + stmts.filterNot { + it is Stmt.Function || (it is Stmt.ExprStmt && it.expr is Expr.InlinedFunction) + }, + ), + ) + val topLevelInlineFuncs = + stmts + .filterIsInstance() + .map { it.expr } + .filterIsInstance() val topLevelfuncs = stmts.filterIsInstance() - val slangModule = SlangModule(listOf(moduleMain) + topLevelfuncs, topLevelInlineFuncs ) + val slangModule = SlangModule(listOf(moduleMain) + topLevelfuncs, topLevelInlineFuncs) ProgramUnit(listOf(slangModule)) } expr.codeInfo = createSourceCodeInfo(ctx) diff --git a/src/test/kotlin/ControlFlowGraphTest.kt b/src/test/kotlin/ControlFlowGraphTest.kt index 078e2a6..87ee6be 100644 --- a/src/test/kotlin/ControlFlowGraphTest.kt +++ b/src/test/kotlin/ControlFlowGraphTest.kt @@ -98,8 +98,8 @@ class ControlFlowGraphTest { assertTrue(result is Result.Ok) val programUnit = (result as Result.Ok).value - // Get the function - val function = programUnit.stmt.filterIsInstance().firstOrNull() + // Get the user-defined function (skip synthetic module main) + val function = programUnit.stmt.flatMap { it.functions }.firstOrNull { it.name != "__module__main__" } assertNotNull(function) val cfg = function.buildCFG() diff --git a/src/test/kotlin/IRBuilderTests.kt b/src/test/kotlin/IRBuilderTests.kt index 5184702..db7beb7 100644 --- a/src/test/kotlin/IRBuilderTests.kt +++ b/src/test/kotlin/IRBuilderTests.kt @@ -14,8 +14,8 @@ class IRBuilderTests { fun buildAst(testCase: URL): String { val hlir = (File2ParseTreeTransformer() then ParseTree2HlirTrasnformer()).invoke(File(testCase.toURI())) return hlir.fold( - {it.prettyPrint(0)}, - {""} + { it.prettyPrint(0) }, + { "" }, ) }