diff --git a/src/main/scala/slick/migration/api/Dialect.scala b/src/main/scala/slick/migration/api/Dialect.scala index 280a0eb4..7c22757c 100644 --- a/src/main/scala/slick/migration/api/Dialect.scala +++ b/src/main/scala/slick/migration/api/Dialect.scala @@ -32,13 +32,13 @@ class Dialect[-P <: JdbcProfile] extends AstHelpers { case None => quoteIdentifier(t.tableName) } - protected def quotedColumnNames(ns: Seq[FieldSymbol]) = ns.map(fs => quoteIdentifier(fs.name)) + protected def quotedColumnNames(ns: Seq[FieldSymbol]): Seq[String] = ns.map(fs => quoteIdentifier(fs.name)) def columnType(ci: ColumnInfo): String = ci.sqlType def autoInc(ci: ColumnInfo) = if(ci.autoInc) " AUTOINCREMENT" else "" - def primaryKey(ci: ColumnInfo, newTable: Boolean) = + def primaryKey(ci: ColumnInfo, newTable: Boolean): String = (if (newTable && ci.isPk) " PRIMARY KEY" else "") + autoInc(ci) def notNull(ci: ColumnInfo) = if (ci.notNull) " NOT NULL" else "" @@ -50,19 +50,19 @@ class Dialect[-P <: JdbcProfile] extends AstHelpers { s"$name $typ$default${ notNull(ci) }${ primaryKey(ci, newTable) }" } - def columnList(columns: Seq[FieldSymbol]) = + def columnList(columns: Seq[FieldSymbol]): String = quotedColumnNames(columns).mkString("(", ", ", ")") - def createTable(table: TableInfo, columns: Seq[ColumnInfo]): List[String] = List( + def createTable(table: TableInfo, columns: Seq[ColumnInfo], primaryKeys: Seq[PrimaryKeyInfo] = Nil): List[String] = List( s"""create table ${quoteTableName(table)} ( | ${columns map { columnSql(_, newTable = true) } mkString ", "} |)""".stripMargin - ) + ) ++ primaryKeys.map(info => createPrimaryKey(table, info.name, info.columns)) def dropTable(table: TableInfo): String = s"drop table ${quoteTableName(table)}" - def renameTable(table: TableInfo, to: String) = + def renameTable(table: TableInfo, to: String): String = s"""alter table ${quoteTableName(table)} | rename to ${quoteIdentifier(to)}""".stripMargin @@ -77,15 +77,15 @@ class Dialect[-P <: JdbcProfile] extends AstHelpers { def dropConstraint(table: TableInfo, name: String) = s"alter table ${quoteTableName(table)} drop constraint ${quoteIdentifier(name)}" - def dropForeignKey(sourceTable: TableInfo, name: String) = + def dropForeignKey(sourceTable: TableInfo, name: String): String = dropConstraint(sourceTable, name) - def createPrimaryKey(table: TableInfo, name: String, columns: Seq[FieldSymbol]) = + def createPrimaryKey(table: TableInfo, name: String, columns: Seq[FieldSymbol]): String = s"""alter table ${quoteTableName(table)} | add constraint ${quoteIdentifier(name)} primary key | ${columnList(columns)}""".stripMargin - def dropPrimaryKey(table: TableInfo, name: String) = + def dropPrimaryKey(table: TableInfo, name: String): String = dropConstraint(table, name) def createIndex(index: IndexInfo) = @@ -100,11 +100,11 @@ class Dialect[-P <: JdbcProfile] extends AstHelpers { s"alter index ${quoteIdentifier(old.name)} rename to ${quoteIdentifier(newName)}" ) - def addColumn(table: TableInfo, column: ColumnInfo) = + def addColumn(table: TableInfo, column: ColumnInfo): String = s"""alter table ${quoteTableName(table)} | add column ${columnSql(column, newTable = false)}""".stripMargin - def addColumnWithInitialValue(table: TableInfo, column: ColumnInfo, rawSqlExpr: String) = + def addColumnWithInitialValue(table: TableInfo, column: ColumnInfo, rawSqlExpr: String): List[String] = List(addColumn(table, column.copy(default = Some(rawSqlExpr)))) ++ (if (column.default.contains(rawSqlExpr)) Nil else List(alterColumnDefault(table, column))) @@ -113,7 +113,7 @@ class Dialect[-P <: JdbcProfile] extends AstHelpers { | drop column ${quoteIdentifier(column)}""".stripMargin ) - def renameColumn(table: TableInfo, from: String, to: String) = + def renameColumn(table: TableInfo, from: String, to: String): String = s"""alter table ${quoteTableName(table)} | alter column ${quoteIdentifier(from)} | rename to ${quoteIdentifier(to)}""".stripMargin @@ -126,28 +126,32 @@ class Dialect[-P <: JdbcProfile] extends AstHelpers { | set data type ${column.sqlType}""".stripMargin ) - def alterColumnDefault(table: TableInfo, column: ColumnInfo) = + def alterColumnDefault(table: TableInfo, column: ColumnInfo): String = s"""alter table ${quoteTableName(table)} | alter column ${quoteIdentifier(column.name)} | set default ${column.default getOrElse "null"}""".stripMargin - def alterColumnNullability(table: TableInfo, column: ColumnInfo) = + def alterColumnNullability(table: TableInfo, column: ColumnInfo): String = s"""alter table ${quoteTableName(table)} | alter column ${quoteIdentifier(column.name)} | ${if (column.notNull) "set" else "drop"} not null""".stripMargin - private def partition[A, B](xs: List[A])(toB: PartialFunction[A, B]): (List[B], List[A]) = - xs.foldLeft((List.empty[B], List.empty[A])) { - case ((bs, as), a) => - toB.andThen(b => (b :: bs, as)).applyOrElse(a, (_: A) => (bs, a :: as)) + private def partition(xs: List[TableMigration.Action]): (List[AddColumn], List[AddPrimaryKey], List[TableMigration.Action]) = + xs.foldLeft((List.empty[AddColumn], List.empty[AddPrimaryKey], List.empty[TableMigration.Action])) { + case ((cs, bs, as), a) => + a match { + case ac: AddColumn => (ac :: cs, bs, as) + case ap: AddPrimaryKey => (cs, ap :: bs, as) + case _ => (cs, bs, a :: as) + } } def migrateTable(table: TableInfo, actions: List[TableMigration.Action]): List[String] = { def loop(actions: List[TableMigration.Action]): List[String] = actions match { case Nil => Nil case CreateTable :: rest => - val (cols, other) = partition(rest) { case a: AddColumn => a } - createTable(table, cols.map(_.info)) ::: loop(other) + val (cols, pKeys, other) = partition(rest) + createTable(table, cols.map(_.info), pKeys.map(_.info)) ::: loop(other) case AlterColumnType(info) :: rest => alterColumnType(table, info) ::: loop(rest) case DropTable :: rest => dropTable(table) :: loop(rest) case RenameTableTo(to) :: rest => renameTable(table, to) :: loop(rest) @@ -189,7 +193,7 @@ class DerbyDialect extends Dialect[DerbyProfile] { override def autoInc(ci: ColumnInfo) = if(ci.autoInc) " GENERATED BY DEFAULT AS IDENTITY" else "" - override def alterColumnType(table: TableInfo, column: ColumnInfo) = { + override def alterColumnType(table: TableInfo, column: ColumnInfo): List[String] = { val tmpColumnName = "temp_column"+(math.random*1000000).toInt val tmpColumn = column.copy(name = tmpColumnName) @@ -202,7 +206,7 @@ class DerbyDialect extends Dialect[DerbyProfile] { override def renameColumn(table: TableInfo, from: String, to: String) = s"rename column ${quoteTableName(table)}.${quoteIdentifier(from)} to ${quoteIdentifier(to)}" - override def alterColumnNullability(table: TableInfo, column: ColumnInfo) = + override def alterColumnNullability(table: TableInfo, column: ColumnInfo): String = s"""alter table ${quoteTableName(table)} | alter column ${quoteIdentifier(column.name)} | ${if (column.notNull) "not" else ""} null""".stripMargin @@ -210,7 +214,7 @@ class DerbyDialect extends Dialect[DerbyProfile] { override def renameTable(table: TableInfo, to: String) = s"rename table ${quoteTableName(table)} to ${quoteIdentifier(to)}" - override def renameIndex(old: IndexInfo, newName: String) = List( + override def renameIndex(old: IndexInfo, newName: String): List[String] = List( s"rename index ${quoteIdentifier(old.name)} to ${quoteIdentifier(newName)}" ) } @@ -234,7 +238,7 @@ class SQLiteDialect extends Dialect[SQLiteProfile] with SimulatedRenameIndex[SQL class HsqldbDialect extends Dialect[HsqldbProfile] { override def autoInc(ci: ColumnInfo) = if(ci.autoInc) " GENERATED BY DEFAULT AS IDENTITY" else "" - override def primaryKey(ci: ColumnInfo, newTable: Boolean) = + override def primaryKey(ci: ColumnInfo, newTable: Boolean): String = autoInc(ci) + (if (newTable && ci.isPk) " PRIMARY KEY" else "") override def notNull(ci: ColumnInfo) = if (ci.notNull && !ci.isPk) " NOT NULL" else "" @@ -258,23 +262,23 @@ class MySQLDialect extends Dialect[MySQLProfile] with SimulatedRenameIndex[MySQL override def renameColumn(table: TableInfo, from: String, to: String) = s"ALTER TABLE ${quoteTableName(table)} RENAME COLUMN ${quoteIdentifier(from)} TO ${quoteIdentifier(to)}" - override def renameColumn(table: TableInfo, from: ColumnInfo, to: String) = { + override def renameColumn(table: TableInfo, from: ColumnInfo, to: String): String = { val newCol = from.copy(name = to) s"""alter table ${quoteTableName(table)} | change ${quoteIdentifier(from.name)} | ${columnSql(newCol, newTable = false)}""".stripMargin } - override def alterColumnNullability(table: TableInfo, column: ColumnInfo) = + override def alterColumnNullability(table: TableInfo, column: ColumnInfo): String = renameColumn(table, column, column.name) - override def alterColumnType(table: TableInfo, column: ColumnInfo) = + override def alterColumnType(table: TableInfo, column: ColumnInfo): List[String] = List(renameColumn(table, column, column.name)) override def dropForeignKey(table: TableInfo, name: String) = s"alter table ${quoteTableName(table)} drop foreign key ${quoteIdentifier(name)}" - override def createPrimaryKey(table: TableInfo, name: String, columns: Seq[FieldSymbol]) = + override def createPrimaryKey(table: TableInfo, name: String, columns: Seq[FieldSymbol]): String = s"""alter table ${quoteTableName(table)} | add constraint primary key | ${columnList(columns)}""".stripMargin @@ -291,8 +295,16 @@ class PostgresDialect extends Dialect[PostgresProfile] { case (true, "BIGINT") => "BIGSERIAL" case (true, _) => throw new RuntimeException("Unsupported autoincrement type") } + + override def createTable(table: TableInfo, columns: Seq[ColumnInfo], primaryKeys: Seq[PrimaryKeyInfo] = Nil): List[String] = List( + s"""create table ${quoteTableName(table)} ( + | ${columns map { columnSql(_, newTable = true) } mkString("", ", ", if (columns.nonEmpty && primaryKeys.nonEmpty) "," else "")} + | ${if (primaryKeys.nonEmpty) primaryKeys.map{ ci => quoteIdentifier(ci.name) }.mkString("primary key (", ", ", ")") else ""} + |)""".stripMargin + ) + override def autoInc(ci: ColumnInfo) = "" - override def renameColumn(table: TableInfo, from: String, to: String) = + override def renameColumn(table: TableInfo, from: String, to: String): String = s"""alter table ${quoteTableName(table)} | rename column ${quoteIdentifier(from)} | to ${quoteIdentifier(to)}""".stripMargin @@ -300,8 +312,8 @@ class PostgresDialect extends Dialect[PostgresProfile] { class OracleDialect extends Dialect[OracleProfile] { - override def createTable(table: TableInfo, columns: Seq[ColumnInfo]): List[String] = { - super.createTable(table, columns) ++ + override def createTable(table: TableInfo, columns: Seq[ColumnInfo], primaryKeys: Seq[PrimaryKeyInfo] = Nil): List[String] = { + super.createTable(table, columns, primaryKeys) ++ columns.filter(_.autoInc).flatMap(addAutoInc(table, _, 1L)) }