diff --git a/squash-core/src/org/jetbrains/squash/dialect/BaseSQLDialect.kt b/squash-core/src/org/jetbrains/squash/dialect/BaseSQLDialect.kt index 5b8e827..466310d 100644 --- a/squash-core/src/org/jetbrains/squash/dialect/BaseSQLDialect.kt +++ b/squash-core/src/org/jetbrains/squash/dialect/BaseSQLDialect.kt @@ -93,6 +93,9 @@ open class BaseSQLDialect(val name: String) : SQLDialect { is FunctionExpression -> { appendFunctionExpression(this, expression) } + is CaseExpression<*> -> { + appendCaseExpression(this, expression) + } is DialectExtension -> { expression.appendTo(this, this@BaseSQLDialect) } @@ -141,6 +144,30 @@ open class BaseSQLDialect(val name: String) : SQLDialect { }) } + open fun appendCaseExpression(builder: SQLStatementBuilder, expression: CaseExpression) = with(builder) { + if (expression.target == null) { + append("CASE") + } else { + append("CASE (") + appendExpression(this, expression.target) + append(")") + } + + for (whenThenClause in expression.clauses) { + append(" WHEN (") + appendExpression(this, whenThenClause.whenClause) + append(") THEN ") + appendExpression(this, whenThenClause.thenClause) + } + + if (expression.finalClause != null) { + append(" ELSE ") + appendExpression(this, expression.finalClause!!) + } + + append(" END") + } + open fun appendFunctionExpression(builder: SQLStatementBuilder, expression: FunctionExpression) = with(builder) { when (expression) { is CountExpression -> { diff --git a/squash-core/src/org/jetbrains/squash/expressions/CaseExpression.kt b/squash-core/src/org/jetbrains/squash/expressions/CaseExpression.kt new file mode 100644 index 0000000..7a397b7 --- /dev/null +++ b/squash-core/src/org/jetbrains/squash/expressions/CaseExpression.kt @@ -0,0 +1,28 @@ +package org.jetbrains.squash.expressions + +class CaseExpression(val target:Expression<*>? = null) : Expression { + val clauses = mutableListOf>() + var finalClause:Expression? = null + + fun whenClause(expression:Expression<*>):WhenThenClause = WhenThenClause(this, expression).apply { + this@CaseExpression.clauses.add(this) + } + + fun elseClause(expression:Expression):CaseExpression = this.apply { + finalClause = expression + } + + class WhenThenClause( + val caseExpression:CaseExpression, + val whenClause:Expression<*> + ) { + lateinit var thenClause:Expression + + fun thenClause(expression:Expression):CaseExpression { + thenClause = expression + return caseExpression + } + } +} + +inline fun case(target:Expression<*>? = null, clauses:CaseExpression.() -> Unit) = CaseExpression(target).apply(clauses) diff --git a/squash-core/test/org/jetbrains/squash/tests/QueryTests.kt b/squash-core/test/org/jetbrains/squash/tests/QueryTests.kt index da2066a..8f941e6 100644 --- a/squash-core/test/org/jetbrains/squash/tests/QueryTests.kt +++ b/squash-core/test/org/jetbrains/squash/tests/QueryTests.kt @@ -3,8 +3,9 @@ package org.jetbrains.squash.tests import org.jetbrains.squash.definition.* import org.jetbrains.squash.expressions.* import org.jetbrains.squash.query.* -import org.jetbrains.squash.results.* -import org.jetbrains.squash.statements.* +import org.jetbrains.squash.results.get +import org.jetbrains.squash.statements.insertInto +import org.jetbrains.squash.statements.values import org.jetbrains.squash.tests.data.* import kotlin.test.* @@ -12,7 +13,7 @@ abstract class QueryTests : DatabaseTests { open fun nullsLast(sql: String): String = "$sql NULLS LAST" @Test fun selectLiteral() { - withTables() { + withTables { val eugene = literal("eugene") val query = select { eugene } @@ -447,7 +448,43 @@ abstract class QueryTests : DatabaseTests { } } - + + @Test fun selectCaseStatementTrue() { + withTables { + val query = select( + case(literal(5)) { + whenClause(literal(6)).thenClause(literal("false")) + whenClause(literal(5)).thenClause(literal("true")) + whenClause(literal(4)).thenClause(literal("false")) + elseClause(literal("false")) + }) + + connection.dialect.statementSQL(query).assertSQL { + "SELECT CASE (?) WHEN (?) THEN ? WHEN (?) THEN ? WHEN (?) THEN ? ELSE ? END" + } + + assertTrue { query.execute().single().get(0).toBoolean() } + } + } + + @Test fun selectCaseStatementFalse() { + withTables { + val query = select( + case(literal(-1)) { + whenClause(literal(6)).thenClause(literal("false")) + whenClause(literal(5)).thenClause(literal("true")) + whenClause(literal(4)).thenClause(literal("false")) + elseClause(literal("false")) + }) + + connection.dialect.statementSQL(query).assertSQL { + "SELECT CASE (?) WHEN (?) THEN ? WHEN (?) THEN ? WHEN (?) THEN ? ELSE ? END" + } + + assertFalse { query.execute().single().get(0).toBoolean() } + } + } + @Test fun selectFromNestedQuery() { withCities { val query = from(select(Citizens.name, Citizens.id).from(Citizens).alias("Citizens"))