diff --git a/src/main/kotlin/Main.kt b/src/main/kotlin/Main.kt index 32730fd..55d64d7 100644 --- a/src/main/kotlin/Main.kt +++ b/src/main/kotlin/Main.kt @@ -17,12 +17,14 @@ import slang.parser.File2ParseTreeTransformer import slang.repl.Repl import slang.runtime.ConcreteState import slang.runtime.Interpreter +import slang.typeinfer.typeCheck import java.io.File import java.nio.file.Paths class SlangCLI : CliktCommand(name = "slang") { private val filename by argument(help = "Slang source file to execute").optional() private val hlir by option("--hlir", help = "Output HLIR representation instead of running").flag() + private val typecheckOnly by option("--typecheck", help = "Run Hindley-Milner type inference and report errors").flag() private val output by option("-o", help = "Output file for HLIR (default: stdout)") init { @@ -54,6 +56,8 @@ class SlangCLI : CliktCommand(name = "slang") { onSuccess = { programUnit -> if (hlir) { outputHlir(programUnit) + } else if (typecheckOnly) { + runTypeCheck(programUnit) } else { runProgram(programUnit) } @@ -84,6 +88,18 @@ class SlangCLI : CliktCommand(name = "slang") { } } + private fun runTypeCheck(programUnit: ProgramUnit) { + val errors = typeCheck(programUnit) + if (errors.isEmpty()) { + echo("Type checking passed. No errors found.") + } else { + echo("Type checking failed:", err = true) + for (e in errors) { + echo(" ${e.location} ${e.message}", err = true) + } + } + } + private fun runProgram(programUnit: ProgramUnit) { try { val interpreter = Interpreter() diff --git a/src/main/kotlin/slang/typeinfer/TypeInference.kt b/src/main/kotlin/slang/typeinfer/TypeInference.kt new file mode 100644 index 0000000..cc20127 --- /dev/null +++ b/src/main/kotlin/slang/typeinfer/TypeInference.kt @@ -0,0 +1,420 @@ +package slang.typeinfer + +import slang.common.CodeInfo +import slang.common.Result +import slang.hlir.Expr +import slang.hlir.Operator +import slang.hlir.ProgramUnit +import slang.hlir.Stmt + +/** + * Typing environment: maps variable names to polymorphic type schemes. + * Immutable — extending returns a new environment. + */ +class TypeEnv( + private val bindings: Map = emptyMap(), +) { + operator fun get(name: String): TypeScheme? = bindings[name] + + fun extend( + name: String, + scheme: TypeScheme, + ): TypeEnv = TypeEnv(bindings + (name to scheme)) + + fun extend(pairs: List>): TypeEnv = TypeEnv(bindings + pairs) + + /** Free type-variable ids across all schemes in the environment. */ + fun freeVars(): Set = bindings.values.flatMap { freeVars(it.type) - it.vars }.toSet() +} + +/** + * Hindley-Milner type inference engine for Slang HLIR. + * + * Implements Algorithm W with let-polymorphism. + */ +class HindleyMilnerInference { + private var nextId = 0 + private val errors = mutableListOf() + + fun fresh(): SlangType.TVar = SlangType.TVar(nextId++) + + // ---- Generalize / Instantiate ---- + + /** Generalize a type w.r.t. variables NOT free in the environment. */ + fun generalize( + env: TypeEnv, + t: SlangType, + ): TypeScheme { + val envFree = env.freeVars() + val typeVars = freeVars(t) - envFree + return TypeScheme(typeVars, t) + } + + /** Instantiate a type scheme with fresh variables. */ + fun instantiate(scheme: TypeScheme): SlangType { + val subst = scheme.vars.associateWith { fresh() as SlangType } + return applySubst(subst, scheme.type) + } + + private fun applySubst( + subst: Map, + t: SlangType, + ): SlangType { + val p = prune(t) + return when (p) { + is SlangType.TVar -> subst[p.id] ?: p + is SlangType.TFun -> SlangType.TFun(p.params.map { applySubst(subst, it) }, applySubst(subst, p.ret)) + is SlangType.TArray -> SlangType.TArray(applySubst(subst, p.elem)) + is SlangType.TRef -> SlangType.TRef(applySubst(subst, p.inner)) + is SlangType.TRecord -> SlangType.TRecord(p.fields.mapValues { applySubst(subst, it.value) }) + else -> p + } + } + + // ---- Public entry point ---- + + fun inferProgram(program: ProgramUnit): List { + errors.clear() + var env = TypeEnv() + for (module in program.stmt) { + // First pass: register all top-level functions (excluding __module__main__) + val moduleMain = module.functions.find { it.name == "__module__main__" } + for (fn in module.functions) { + if (fn !== moduleMain) { + env = inferFunctionDecl(fn, env) + } + } + // Second pass: infer the module main body + if (moduleMain != null) { + inferBlock(moduleMain.body, env) + } + } + return errors + } + + // ---- Functions ---- + + private fun inferFunctionDecl( + fn: Stmt.Function, + env: TypeEnv, + ): TypeEnv { + val paramTypes = fn.params.map { fresh() as SlangType } + val retType: SlangType = fresh() + val funType = SlangType.TFun(paramTypes, retType) + + // Extend env with the function name (monomorphic, for recursion) + val innerEnv = + env + .extend(fn.name, TypeScheme(emptySet(), funType)) + .extend(fn.params.zip(paramTypes).map { (n, t) -> n to TypeScheme(emptySet(), t) }) + + val bodyType = inferBlock(fn.body, innerEnv) + safeUnify(retType, bodyType, fn.codeInfo) + + // Generalize in the outer env + val scheme = generalize(env, funType) + return env.extend(fn.name, scheme) + } + + // ---- Statements ---- + + /** + * Infer types for a block and return the type of the last expression / return. + * For blocks that produce no value, returns TUnit. + */ + private fun inferBlock( + block: Stmt.BlockStmt, + env: TypeEnv, + ): SlangType { + var currentEnv = env + var resultType: SlangType = SlangType.TUnit + for (stmt in block.stmts) { + val (newEnv, ty) = inferStmt(stmt, currentEnv) + currentEnv = newEnv + resultType = ty + } + return resultType + } + + /** + * Returns (possibly extended env, type produced by this statement). + * For statements that don't produce a value the type is TUnit. + * For return statements the type is the type of the returned expression. + */ + private fun inferStmt( + stmt: Stmt, + env: TypeEnv, + ): Pair = + when (stmt) { + is Stmt.LetStmt -> { + val exprType = inferExpr(stmt.expr, env) + val scheme = generalize(env, exprType) + Pair(env.extend(stmt.name, scheme), SlangType.TUnit) + } + + is Stmt.AssignStmt -> { + val lhsType = inferExpr(stmt.lhs, env) + val rhsType = inferExpr(stmt.expr, env) + safeUnify(lhsType, rhsType, stmt.codeInfo) + Pair(env, SlangType.TUnit) + } + + is Stmt.ExprStmt -> { + val t = inferExpr(stmt.expr, env) + Pair(env, t) + } + + is Stmt.ReturnStmt -> { + val t = inferExpr(stmt.expr, env) + Pair(env, t) + } + + is Stmt.PrintStmt -> { + stmt.args.forEach { inferExpr(it, env) } + Pair(env, SlangType.TUnit) + } + + is Stmt.IfStmt -> { + val condType = inferExpr(stmt.condition, env) + safeUnify(condType, SlangType.TBool, stmt.codeInfo) + val thenType = inferBlock(stmt.thenBody, env) + val elseType = inferBlock(stmt.elseBody, env) + safeUnify(thenType, elseType, stmt.codeInfo) + Pair(env, thenType) + } + + is Stmt.WhileStmt -> { + val condType = inferExpr(stmt.condition, env) + safeUnify(condType, SlangType.TBool, stmt.codeInfo) + inferBlock(stmt.body, env) + Pair(env, SlangType.TUnit) + } + + is Stmt.BlockStmt -> { + val t = inferBlock(stmt, env) + Pair(env, t) + } + + is Stmt.Function -> { + val newEnv = inferFunctionDecl(stmt, env) + Pair(newEnv, SlangType.TUnit) + } + + is Stmt.DerefStmt -> { + val refType = inferExpr(stmt.lhs, env) + val valType = inferExpr(stmt.rhs, env) + val inner: SlangType = fresh() + safeUnify(refType, SlangType.TRef(inner), stmt.codeInfo) + safeUnify(inner, valType, stmt.codeInfo) + Pair(env, SlangType.TUnit) + } + + is Stmt.StructStmt -> { + val fieldTypes = stmt.fields.mapValues { inferExpr(it.value, env) } + val recordType = SlangType.TRecord(fieldTypes) + val scheme = generalize(env, recordType) + Pair(env.extend(stmt.id, scheme), SlangType.TUnit) + } + + is Stmt.Break -> Pair(env, SlangType.TUnit) + is Stmt.Continue -> Pair(env, SlangType.TUnit) + } + + // ---- Expressions ---- + + private fun inferExpr( + expr: Expr, + env: TypeEnv, + ): SlangType = + when (expr) { + is Expr.NumberLiteral -> SlangType.TNum + is Expr.BoolLiteral -> SlangType.TBool + is Expr.StringLiteral -> SlangType.TString + is Expr.NoneValue -> SlangType.TNone + + is Expr.VarExpr -> { + val scheme = env[expr.name] + if (scheme != null) { + instantiate(scheme) + } else { + errors.add(TypeError(expr.codeInfo, "Undefined variable: ${expr.name}")) + fresh() + } + } + + is Expr.ReadInputExpr -> fresh() // could be Num or String + + is Expr.BinaryExpr -> inferBinaryExpr(expr, env) + + is Expr.IfExpr -> { + val condType = inferExpr(expr.condition, env) + safeUnify(condType, SlangType.TBool, expr.codeInfo) + val thenType = inferExpr(expr.thenExpr, env) + val elseType = inferExpr(expr.elseExpr, env) + safeUnify(thenType, elseType, expr.codeInfo) + thenType + } + + is Expr.ParenExpr -> inferExpr(expr.expr, env) + + is Expr.InlinedFunction -> { + val paramTypes = expr.params.map { fresh() as SlangType } + val innerEnv = + env.extend( + expr.params.zip(paramTypes).map { (n, t) -> n to TypeScheme(emptySet(), t) }, + ) + val bodyType = inferBlock(expr.body, innerEnv) + SlangType.TFun(paramTypes, bodyType) + } + + is Expr.NamedFunctionCall -> { + val funType = env[expr.name] + if (funType == null) { + errors.add(TypeError(expr.codeInfo, "Undefined function: ${expr.name}")) + fresh() + } else { + val instType = instantiate(funType) + val argTypes = expr.arguments.map { inferExpr(it, env) } + val retType: SlangType = fresh() + safeUnify(instType, SlangType.TFun(argTypes, retType), expr.codeInfo) + retType + } + } + + is Expr.ExpressionFunctionCall -> { + val targetType = inferExpr(expr.target, env) + val argTypes = expr.arguments.map { inferExpr(it, env) } + val retType: SlangType = fresh() + safeUnify(targetType, SlangType.TFun(argTypes, retType), expr.codeInfo) + retType + } + + is Expr.ArrayInit -> { + val elemType: SlangType = fresh() + for (el in expr.elements) { + val t = inferExpr(el, env) + safeUnify(elemType, t, expr.codeInfo) + } + SlangType.TArray(elemType) + } + + is Expr.ArrayAccess -> { + val arrType = inferExpr(expr.array, env) + val idxType = inferExpr(expr.index, env) + val elemType: SlangType = fresh() + safeUnify(arrType, SlangType.TArray(elemType), expr.codeInfo) + safeUnify(idxType, SlangType.TNum, expr.codeInfo) + elemType + } + + is Expr.Record -> { + val fieldTypes = expr.expression.associate { (name, e) -> name to inferExpr(e, env) } + SlangType.TRecord(fieldTypes) + } + + is Expr.FieldAccess -> { + val recordType = inferExpr(expr.lhs, env) + // We know rhs is always VarExpr (from the AST builder) + val fieldName = (expr.rhs as Expr.VarExpr).name + val fieldType: SlangType = fresh() + // For records, we need structural access; try to unify if already a record + val pruned = prune(recordType) + if (pruned is SlangType.TRecord) { + val ft = pruned.fields[fieldName] + if (ft != null) { + safeUnify(fieldType, ft, expr.codeInfo) + } else { + errors.add(TypeError(expr.codeInfo, "Record has no field '$fieldName'")) + } + } + // If it's a type variable, we can't know the fields yet — return fresh + fieldType + } + + is Expr.RefExpr -> { + val innerType = inferExpr(expr.expr, env) + SlangType.TRef(innerType) + } + + is Expr.DerefExpr -> { + val refType = inferExpr(expr.expr, env) + val innerType: SlangType = fresh() + safeUnify(refType, SlangType.TRef(innerType), expr.codeInfo) + innerType + } + } + + // ---- Binary expressions ---- + + private fun inferBinaryExpr( + expr: Expr.BinaryExpr, + env: TypeEnv, + ): SlangType { + val leftType = inferExpr(expr.left, env) + val rightType = inferExpr(expr.right, env) + + return when (expr.op) { + Operator.PLUS -> { + // PLUS works on Num+Num or String+String; default to unifying both sides + val resultType: SlangType = fresh() + safeUnify(leftType, resultType, expr.codeInfo) + safeUnify(rightType, resultType, expr.codeInfo) + resultType + } + Operator.MINUS, Operator.TIMES, Operator.DIV, Operator.MOD -> { + safeUnify(leftType, SlangType.TNum, expr.codeInfo) + safeUnify(rightType, SlangType.TNum, expr.codeInfo) + SlangType.TNum + } + Operator.LT, Operator.GT, Operator.LEQ, Operator.GEQ -> { + safeUnify(leftType, SlangType.TNum, expr.codeInfo) + safeUnify(rightType, SlangType.TNum, expr.codeInfo) + SlangType.TBool + } + Operator.EQ, Operator.NEQ -> { + safeUnify(leftType, rightType, expr.codeInfo) + SlangType.TBool + } + Operator.AND, Operator.OR -> { + safeUnify(leftType, SlangType.TBool, expr.codeInfo) + safeUnify(rightType, SlangType.TBool, expr.codeInfo) + SlangType.TBool + } + } + } + + // ---- Helpers ---- + + /** Unify with error collection rather than throwing. */ + private fun safeUnify( + a: SlangType, + b: SlangType, + location: CodeInfo, + ) { + try { + unify(a, b, location) + } catch (e: TypeError) { + errors.add(e) + } + } +} + +/** + * Convenience function: run type inference on a program, returning errors. + */ +fun typeCheck(program: ProgramUnit): List = HindleyMilnerInference().inferProgram(program) + +/** + * Pipeline-compatible transform that runs type inference. + * Passes the ProgramUnit through unchanged on success. + */ +class TypeCheckTransform : slang.common.Transform { + override fun transform(input: ProgramUnit): Result> { + val errors = typeCheck(input) + return if (errors.isEmpty()) { + Result.ok(input) + } else { + Result.err(errors.map { slang.parser.CompilerError(it.location, it.message ?: "Type error") }) + } + } +} diff --git a/src/main/kotlin/slang/typeinfer/Types.kt b/src/main/kotlin/slang/typeinfer/Types.kt new file mode 100644 index 0000000..e0cd1ee --- /dev/null +++ b/src/main/kotlin/slang/typeinfer/Types.kt @@ -0,0 +1,95 @@ +package slang.typeinfer + +/** + * Type representation for Hindley-Milner type inference. + * + * Type variables use a mutable union-find structure: each [TVar] has a nullable + * [TVar.bound] field. When bound is null the variable is free; when non-null it + * points to the type it has been unified with. + */ +sealed class SlangType { + /** A type variable (possibly bound via unification). */ + class TVar( + val id: Int, + var bound: SlangType? = null, + ) : SlangType() { + override fun toString(): String = if (bound != null) bound!!.toString() else "t$id" + } + + object TNum : SlangType() { + override fun toString() = "Num" + } + + object TBool : SlangType() { + override fun toString() = "Bool" + } + + object TString : SlangType() { + override fun toString() = "String" + } + + object TNone : SlangType() { + override fun toString() = "None" + } + + object TUnit : SlangType() { + override fun toString() = "Unit" + } + + /** Function type: (param1, param2, ...) -> ret */ + data class TFun( + val params: List, + val ret: SlangType, + ) : SlangType() { + override fun toString() = "(${params.joinToString(", ")}) -> $ret" + } + + data class TArray( + val elem: SlangType, + ) : SlangType() { + override fun toString() = "[$elem]" + } + + data class TRecord( + val fields: Map, + ) : SlangType() { + override fun toString() = "{${fields.entries.joinToString(", ") { "${it.key}: ${it.value}" }}}" + } + + data class TRef( + val inner: SlangType, + ) : SlangType() { + override fun toString() = "Ref<$inner>" + } +} + +/** A polymorphic type scheme: ∀ vars . type */ +data class TypeScheme( + val vars: Set, + val type: SlangType, +) + +/** Resolve a chain of TVar bindings to the root representative type. */ +fun prune(t: SlangType): SlangType = + when { + t is SlangType.TVar && t.bound != null -> { + val pruned = prune(t.bound!!) + t.bound = pruned // path compression + pruned + } + else -> t + } + +/** Collect free type-variable ids in a type. */ +fun freeVars(t: SlangType): Set = + when (val p = prune(t)) { + is SlangType.TVar -> setOf(p.id) + is SlangType.TFun -> p.params.flatMap { freeVars(it) }.toSet() + freeVars(p.ret) + is SlangType.TArray -> freeVars(p.elem) + is SlangType.TRecord -> + p.fields.values + .flatMap { freeVars(it) } + .toSet() + is SlangType.TRef -> freeVars(p.inner) + else -> emptySet() + } diff --git a/src/main/kotlin/slang/typeinfer/Unification.kt b/src/main/kotlin/slang/typeinfer/Unification.kt new file mode 100644 index 0000000..ef06722 --- /dev/null +++ b/src/main/kotlin/slang/typeinfer/Unification.kt @@ -0,0 +1,86 @@ +package slang.typeinfer + +import slang.common.CodeInfo + +/** + * Unification errors thrown when two types cannot be made equal. + */ +class TypeError( + val location: CodeInfo, + message: String, +) : Exception(message) + +/** + * Unifies two types, mutating [SlangType.TVar.bound] fields. + * Throws [TypeError] if the types are incompatible. + */ +fun unify( + a: SlangType, + b: SlangType, + location: CodeInfo = CodeInfo.generic, +) { + val pa = prune(a) + val pb = prune(b) + + if (pa === pb) return + + when { + pa is SlangType.TVar -> { + if (occursIn(pa.id, pb)) { + throw TypeError(location, "Infinite type: t${pa.id} occurs in $pb") + } + pa.bound = pb + } + pb is SlangType.TVar -> unify(pb, pa, location) + + pa is SlangType.TFun && pb is SlangType.TFun -> { + if (pa.params.size != pb.params.size) { + throw TypeError( + location, + "Function arity mismatch: expected ${pa.params.size} params, got ${pb.params.size}", + ) + } + pa.params.zip(pb.params).forEach { (p1, p2) -> unify(p1, p2, location) } + unify(pa.ret, pb.ret, location) + } + + pa is SlangType.TArray && pb is SlangType.TArray -> + unify(pa.elem, pb.elem, location) + + pa is SlangType.TRef && pb is SlangType.TRef -> + unify(pa.inner, pb.inner, location) + + pa is SlangType.TRecord && pb is SlangType.TRecord -> { + if (pa.fields.keys != pb.fields.keys) { + throw TypeError( + location, + "Record field mismatch: ${pa.fields.keys} vs ${pb.fields.keys}", + ) + } + for (key in pa.fields.keys) { + unify(pa.fields[key]!!, pb.fields[key]!!, location) + } + } + + // Ground types must be identical + pa == pb -> { /* ok */ } + + else -> throw TypeError(location, "Cannot unify $pa with $pb") + } +} + +/** Occurs check: does type variable [varId] appear free in [type]? */ +private fun occursIn( + varId: Int, + type: SlangType, +): Boolean { + val p = prune(type) + return when (p) { + is SlangType.TVar -> p.id == varId + is SlangType.TFun -> p.params.any { occursIn(varId, it) } || occursIn(varId, p.ret) + is SlangType.TArray -> occursIn(varId, p.elem) + is SlangType.TRef -> occursIn(varId, p.inner) + is SlangType.TRecord -> p.fields.values.any { occursIn(varId, it) } + else -> false + } +} diff --git a/src/test/kotlin/TypeInferenceTest.kt b/src/test/kotlin/TypeInferenceTest.kt new file mode 100644 index 0000000..fe3f7c8 --- /dev/null +++ b/src/test/kotlin/TypeInferenceTest.kt @@ -0,0 +1,300 @@ +import org.junit.jupiter.api.Assertions.assertEquals +import org.junit.jupiter.api.Assertions.assertTrue +import org.junit.jupiter.api.Test +import slang.hlir.string2hlir +import slang.typeinfer.typeCheck + +class TypeInferenceTest { + private fun assertNoErrors(code: String) { + val ast = + string2hlir(code).fold( + onSuccess = { it }, + onFailure = { throw AssertionError("Parse error: $it") }, + ) + val errors = typeCheck(ast) + assertTrue(errors.isEmpty(), "Expected no type errors but got:\n${errors.joinToString("\n") { it.message ?: "" }}") + } + + private fun assertHasError( + code: String, + expectedFragment: String, + ) { + val ast = + string2hlir(code).fold( + onSuccess = { it }, + onFailure = { throw AssertionError("Parse error: $it") }, + ) + val errors = typeCheck(ast) + assertTrue(errors.isNotEmpty(), "Expected type errors but got none") + assertTrue( + errors.any { (it.message ?: "").contains(expectedFragment) }, + "Expected error containing '$expectedFragment' but got:\n${errors.joinToString("\n") { it.message ?: "" }}", + ) + } + + private fun assertErrorCount( + code: String, + count: Int, + ) { + val ast = + string2hlir(code).fold( + onSuccess = { it }, + onFailure = { throw AssertionError("Parse error: $it") }, + ) + val errors = typeCheck(ast) + assertEquals( + count, + errors.size, + "Expected $count error(s) but got ${errors.size}:\n${errors.joinToString("\n") { it.message ?: "" }}", + ) + } + + // ---- Well-typed programs ---- + + @Test + fun testSimpleArithmetic() { + assertNoErrors( + """ + let x = 10; + let y = 5; + let sum = x + y; + let product = x * y; + print(sum); + print(product); + """.trimIndent(), + ) + } + + @Test + fun testFactorial() { + assertNoErrors( + """ + fun factorial(n) => if (n == 0) then 1 else n * factorial(n - 1); + let fact = factorial(5); + print(fact); + """.trimIndent(), + ) + } + + @Test + fun testBooleanOps() { + assertNoErrors( + """ + let a = true; + let b = false; + let c = a && b; + let d = a || b; + """.trimIndent(), + ) + } + + @Test + fun testHigherOrderFunction() { + assertNoErrors( + """ + fun apply(f, x) => f(x); + fun double(x) => x * 2; + let result = apply(double, 5); + print(result); + """.trimIndent(), + ) + } + + @Test + fun testAnonymousFunction() { + assertNoErrors( + """ + let add = fun(a, b) => a + b; + let result = add(3, 4); + print(result); + """.trimIndent(), + ) + } + + @Test + fun testIfExpression() { + assertNoErrors( + """ + let x = 5; + let result = if (x > 0) then 1 else -1; + print(result); + """.trimIndent(), + ) + } + + @Test + fun testWhileLoop() { + assertNoErrors( + """ + let i = 1; + while (i <= 3) { + print(i); + i = i + 1; + } + """.trimIndent(), + ) + } + + @Test + fun testGCD() { + assertNoErrors( + """ + fun mod(a, b) => if (a < b) then a else mod(a-b, b); + fun gcd(a, b) => if (b == 0) then a else gcd(b, mod(a,b)); + let result = gcd(48, 18); + print(result); + """.trimIndent(), + ) + } + + @Test + fun testStringConcat() { + assertNoErrors( + """ + let greeting = "Hello"; + let name = "World"; + let message = greeting + name; + print(message); + """.trimIndent(), + ) + } + + @Test + fun testLetPolymorphism() { + assertNoErrors( + """ + fun id(x) => x; + let a = id(42); + let b = id(true); + """.trimIndent(), + ) + } + + @Test + fun testIfStatement() { + assertNoErrors( + """ + let a = 10; + let b = 5; + if (a > b) { + print(a); + } else { + print(b); + } + """.trimIndent(), + ) + } + + @Test + fun testImpureFunction() { + assertNoErrors( + """ + fun compute(x) { + let result = x * 2; + return result; + } + let value = compute(5); + print(value); + """.trimIndent(), + ) + } + + // ---- Type errors ---- + + @Test + fun testArithmeticOnBool() { + assertHasError( + """ + let x = true; + let y = x * 2; + """.trimIndent(), + "Cannot unify", + ) + } + + @Test + fun testBooleanOnNumber() { + assertHasError( + """ + let x = 1; + let y = 2; + let z = x && y; + """.trimIndent(), + "Cannot unify", + ) + } + + @Test + fun testIfConditionNotBool() { + assertHasError( + """ + let x = if (42) then 1 else 2; + """.trimIndent(), + "Cannot unify", + ) + } + + @Test + fun testIfBranchMismatch() { + assertHasError( + """ + let x = if (true) then 1 else false; + """.trimIndent(), + "Cannot unify", + ) + } + + @Test + fun testWhileConditionNotBool() { + assertHasError( + """ + while 42 { + print(1); + } + """.trimIndent(), + "Cannot unify", + ) + } + + @Test + fun testArityMismatch() { + assertHasError( + """ + fun f(x) => x + 1; + let y = f(1, 2); + """.trimIndent(), + "arity", + ) + } + + @Test + fun testSubtractionOnBool() { + assertHasError( + """ + let x = true - false; + """.trimIndent(), + "Cannot unify", + ) + } + + @Test + fun testComparisonOnBool() { + assertHasError( + """ + let x = true < false; + """.trimIndent(), + "Cannot unify", + ) + } + + @Test + fun testAssignTypeMismatch() { + assertHasError( + """ + let x = 1; + x = true; + """.trimIndent(), + "Cannot unify", + ) + } +}