diff --git a/core/src/main/scala/cats/FlatMap.scala b/core/src/main/scala/cats/FlatMap.scala index fe087077d3..9bb64722cd 100644 --- a/core/src/main/scala/cats/FlatMap.scala +++ b/core/src/main/scala/cats/FlatMap.scala @@ -36,7 +36,7 @@ package cats * * Must obey the laws defined in cats.laws.FlatMapLaws. */ -trait FlatMap[F[_]] extends Apply[F] { +trait FlatMap[F[_]] extends Apply[F] with FlatMapArityFunctions[F] { def flatMap[A, B](fa: F[A])(f: A => F[B]): F[B] /** diff --git a/project/Boilerplate.scala b/project/Boilerplate.scala index 9929241083..06348ce071 100644 --- a/project/Boilerplate.scala +++ b/project/Boilerplate.scala @@ -27,6 +27,7 @@ object Boilerplate { GenSemigroupalBuilders, GenSemigroupalArityFunctions, GenApplyArityFunctions, + GenFlatMapArityFunctions, GenTupleSemigroupalSyntax, GenParallelArityFunctions, GenParallelArityFunctions2, @@ -285,6 +286,36 @@ object Boilerplate { } } + object GenFlatMapArityFunctions extends Template { + def filename(root: File) = root / "cats" / "FlatMapArityFunctions.scala" + override def range = 2 to maxArity + def content(tv: TemplateVals) = { + import tv._ + + val tpes = synTypes.map { tpe => + s"F[$tpe]" + } + val fargs = (0 until arity).map("f" + _) + val fparams = fargs.zip(tpes).map { case (v, t) => s"$v:$t" }.mkString(", ") + + block""" + |package cats + | + |/** + | * @groupprio Ungrouped 0 + | * + | * @groupname FlatMapArity flatMap arity + | * @groupdesc FlatMapArity Higher-arity flatMap methods + | * @groupprio FlatMapArity 999 + | */ + |trait FlatMapArityFunctions[F[_]] { self: FlatMap[F] => + - /** @group FlatMapArity */ + - def flatMap$arity[${`A..N`}, Z]($fparams)(f: (${`A..N`}) => F[Z]): F[Z] = self.flatten(self.map$arity($fparams)(f)) + |} + """ + } + } + final case class ParallelNestedExpansions(arity: Int) { val products = (0 until (arity - 2)) .foldRight(s"Parallel.parProduct(m${arity - 2}, m${arity - 1})")((i, acc) => s"Parallel.parProduct(m$i, $acc)") @@ -516,6 +547,12 @@ object Boilerplate { else s"def traverseN[G[_]: Applicative, Z](f: (${`A..N`}) => G[Z])(implicit traverse: Traverse[F], semigroupal: Semigroupal[F]): G[F[Z]] = Semigroupal.traverse$arity($tupleArgs)(f)" + val flatMap = + if (arity == 1) + s"def flatMap[Z](f: (${`A..N`}) => F[Z])(implicit flatMap: FlatMap[F]): F[Z] = flatMap.flatMap($tupleArgs)(f)" + else + s"def flatMapN[Z](f: (${`A..N`}) => F[Z])(implicit flatMap: FlatMap[F]): F[Z] = flatMap.flatMap$arity($tupleArgs)(f)" + block""" |package cats |package syntax @@ -528,6 +565,7 @@ object Boilerplate { - $map - $contramap - $imap + - $flatMap - $tupled - $traverse - def apWith[Z](f: F[(${`A..N`}) => Z])(implicit apply: Apply[F]): F[Z] = apply.ap$n(f)($tupleArgs) diff --git a/tests/shared/src/test/scala/cats/tests/SyntaxSuite.scala b/tests/shared/src/test/scala/cats/tests/SyntaxSuite.scala index 66c2fb9738..1700a201fa 100644 --- a/tests/shared/src/test/scala/cats/tests/SyntaxSuite.scala +++ b/tests/shared/src/test/scala/cats/tests/SyntaxSuite.scala @@ -412,10 +412,19 @@ object SyntaxSuite { val fa = a.pure[F] } - def testFlatMap[F[_]: FlatMap, A, B, C, D]: Unit = { + def testFlatMap[F[_]: FlatMap, A, B, C, D, Z]: Unit = { val a = mock[A] val returnValue = mock[F[Either[A, B]]] val done = a.tailRecM[F, B](a => returnValue) + val tfabc = mock[(F[A], F[B], F[C])] + val ff = mock[(A, B, C) => F[Z]] + + tfabc.flatMapN(ff) + + val tfa = mock[Tuple1[F[A]]] + val ffone = mock[A => F[Z]] + + tfa.flatMap(ffone) val x = mock[Function[A, F[B]]] val y = mock[Function[B, F[C]]]