diff --git a/squash-core/src/org/jetbrains/squash/dialect/BaseSQLDialect.kt b/squash-core/src/org/jetbrains/squash/dialect/BaseSQLDialect.kt index 5b8e827..674daa1 100644 --- a/squash-core/src/org/jetbrains/squash/dialect/BaseSQLDialect.kt +++ b/squash-core/src/org/jetbrains/squash/dialect/BaseSQLDialect.kt @@ -145,29 +145,19 @@ open class BaseSQLDialect(val name: String) : SQLDialect { when (expression) { is CountExpression -> { append("COUNT(") - appendExpression(this, expression.value) + appendExpression(this, expression.value!!) append(")") } is CountDistinctExpression -> { append("COUNT(DISTINCT ") - appendExpression(this, expression.value) - append(")") - } - is MaxExpression -> { - append("MAX(") - appendExpression(this, expression.value) - append(")") - } - is MinExpression -> { - append("MIN(") - appendExpression(this, expression.value) - append(")") - } - is SumExpression -> { - append("SUM(") - appendExpression(this, expression.value) + appendExpression(this, expression.value!!) append(")") } + is GeneralFunctionExpression -> { + append("${expression.name}(") + appendExpression(this, expression.value) + append(")") + } else -> error("Function '$expression' is not supported by ${this@BaseSQLDialect}") } } diff --git a/squash-core/src/org/jetbrains/squash/expressions/FunctionExpression.kt b/squash-core/src/org/jetbrains/squash/expressions/FunctionExpression.kt index 858816e..fccdbb8 100644 --- a/squash-core/src/org/jetbrains/squash/expressions/FunctionExpression.kt +++ b/squash-core/src/org/jetbrains/squash/expressions/FunctionExpression.kt @@ -1,16 +1,23 @@ package org.jetbrains.squash.expressions +import java.math.BigDecimal + interface FunctionExpression : Expression -class CountExpression(val value: Expression<*>) : FunctionExpression -class CountDistinctExpression(val value: Expression<*>) : FunctionExpression -class MinExpression(val value: Expression<*>) : FunctionExpression -class MaxExpression(val value: Expression<*>) : FunctionExpression -class SumExpression(val value: Expression<*>) : FunctionExpression +/** + * Represents any function with a name, single argument, and return value. + */ +class GeneralFunctionExpression( + val name:String, + val value:Expression<*> +) : FunctionExpression + +class CountExpression(val value: Expression<*>? = null) : FunctionExpression +class CountDistinctExpression(val value:Expression<*>? = null) : FunctionExpression fun Expression<*>.count() = CountExpression(this) fun Expression<*>.countDistinct() = CountDistinctExpression(this) -fun Expression<*>.min() = MinExpression(this) -fun Expression<*>.max() = MaxExpression(this) -fun Expression<*>.sum() = SumExpression(this) - +fun Expression.min() = GeneralFunctionExpression("MIN",this) +fun Expression.max() = GeneralFunctionExpression("MAX", this) +fun Expression.sum() = GeneralFunctionExpression("SUM",this) +fun Expression<*>.average() = GeneralFunctionExpression("AVG",this) diff --git a/squash-core/test/org/jetbrains/squash/tests/QueryTests.kt b/squash-core/test/org/jetbrains/squash/tests/QueryTests.kt index da2066a..ca2ac2b 100644 --- a/squash-core/test/org/jetbrains/squash/tests/QueryTests.kt +++ b/squash-core/test/org/jetbrains/squash/tests/QueryTests.kt @@ -6,6 +6,8 @@ import org.jetbrains.squash.query.* import org.jetbrains.squash.results.* import org.jetbrains.squash.statements.* import org.jetbrains.squash.tests.data.* +import java.math.BigDecimal +import java.math.BigInteger import kotlin.test.* abstract class QueryTests : DatabaseTests { @@ -448,6 +450,26 @@ abstract class QueryTests : DatabaseTests { } } + @Test fun selectAggregate() { + withCities { + val query = select( + CityStats.value.min().alias("minimum"), + CityStats.value.max().alias("maximum"), + CityStats.value.average().alias("average") + ) + .from(Cities) + .innerJoin(CityStats, + (Cities.id eq CityStats.cityId) + .and(CityStats.name eq "population") + ) + + val result = query.execute().single() + assertEquals(BigDecimal("1500000"), result["minimum"], "Minimum city population does not match") + assertEquals(BigDecimal("6200000"), result["maximum"], "Maximum city population does not match") + assertEquals(BigInteger("3433333"), result.get("average").toBigInteger(), "Average city population does not match") + } + } + @Test fun selectFromNestedQuery() { withCities { val query = from(select(Citizens.name, Citizens.id).from(Citizens).alias("Citizens")) diff --git a/squash-core/test/org/jetbrains/squash/tests/data/CitiesData.kt b/squash-core/test/org/jetbrains/squash/tests/data/CitiesData.kt index 7085875..dc485db 100644 --- a/squash-core/test/org/jetbrains/squash/tests/data/CitiesData.kt +++ b/squash-core/test/org/jetbrains/squash/tests/data/CitiesData.kt @@ -7,7 +7,7 @@ import org.jetbrains.squash.statements.* import org.jetbrains.squash.tests.* fun DatabaseTests.withCities(statement: Transaction.() -> R) :R { - return withTables(Cities, CitizenData, Citizens, CitizenDataLink) { + return withTables(Cities, CityStats, CitizenData, Citizens, CitizenDataLink) { val spbId = insertInto(Cities).values { it[name] = "St. Petersburg" }.fetch(Cities.id).execute() @@ -16,10 +16,35 @@ fun DatabaseTests.withCities(statement: Transaction.() -> R) :R { it[name] = "Munich" }.fetch(Cities.id).execute() - insertInto(Cities).values { + val pragueId = insertInto(Cities).values { it[name] = "Prague" - }.execute() + }.fetch(Cities.id).execute() + + /* + * Insert City Statistics + */ + + insertInto(CityStats).values { + it[cityId] = spbId + it[name] = "population" + it[value] = 6200000 + }.execute() + + insertInto(CityStats).values { + it[cityId] = munichId + it[name] = "population" + it[value] = 1500000 + }.execute() + + insertInto(CityStats).values { + it[cityId] = pragueId + it[name] = "population" + it[value] = 2600000 + }.execute() + /* + * Insert Citizens + */ insertInto(Citizens).query() .select { literal("andrey").alias("id") } .select { literal("Andrey").alias("name") } diff --git a/squash-core/test/org/jetbrains/squash/tests/data/CitiesSchema.kt b/squash-core/test/org/jetbrains/squash/tests/data/CitiesSchema.kt index f5fd009..eea4759 100644 --- a/squash-core/test/org/jetbrains/squash/tests/data/CitiesSchema.kt +++ b/squash-core/test/org/jetbrains/squash/tests/data/CitiesSchema.kt @@ -10,6 +10,12 @@ object Cities : TableDefinition() { val name = varchar("name", 50) } +object CityStats : TableDefinition() { + val cityId = reference(Cities.id, "cityId") + val name = varchar("name", 50) + val value = long("value") +} + object Citizens : TableDefinition() { val id = varchar("id", 10).primaryKey() val name = varchar("name", length = 50) diff --git a/squash-jdbc/src/org/jetbrains/squash/drivers/JDBCDataConversion.kt b/squash-jdbc/src/org/jetbrains/squash/drivers/JDBCDataConversion.kt index 1d91b80..52ec9fc 100644 --- a/squash-jdbc/src/org/jetbrains/squash/drivers/JDBCDataConversion.kt +++ b/squash-jdbc/src/org/jetbrains/squash/drivers/JDBCDataConversion.kt @@ -1,7 +1,8 @@ package org.jetbrains.squash.drivers import org.jetbrains.squash.connection.* -import java.math.* +import java.math.BigDecimal +import java.math.BigInteger import java.sql.* import java.time.* import kotlin.reflect.* @@ -30,11 +31,16 @@ open class JDBCDataConversion { value is Time -> value.toLocalTime() value is Blob -> JDBCBinaryObject(value.getBytes(1, value.length().toInt())) value is ByteArray && type == BinaryObject::class -> JDBCBinaryObject(value) - type.javaObjectType.isInstance(value) -> value + value is Double && type.javaObjectType == BigDecimal::class.java -> value.toBigDecimal() + value is Int && type.javaObjectType == BigInteger::class.java -> value.toBigInteger() + value is Int && type.javaObjectType == BigDecimal::class.java -> value.toBigDecimal() value is Long && type.javaObjectType == Int::class.javaObjectType -> value.toInt() + value is Long && type.javaObjectType == BigInteger::class.java -> value.toBigInteger() + value is Long && type.javaObjectType == BigDecimal::class.java -> value.toBigDecimal() value is Int && type.javaObjectType == Long::class.javaObjectType -> value.toLong() value is BigInteger && type.javaObjectType == Int::class.javaObjectType -> value.toInt() value is BigInteger && type.javaObjectType == Long::class.javaObjectType -> value.toLong() + type.javaObjectType.isInstance(value) -> value else -> error("Cannot convert value of type `${value.javaClass}` to type `$type`") } }