Skip to content

Commit 2c77412

Browse files
committed
Syntax for polymorphic values.
∀[T].m[A, B[_], ...](e) is rewritten to new T { def m[A, B[_], ...] = e } In case `m` has only one *-kinded type parameter which is is not referenced from `e`, it can be omitted: ∀[T].m(e) is rewritten to new T { def m[A] = e } where `A` is a fresh name. If, in addition, the method name is `apply`, it can be omitted as well: ∀[T](e) is rewritten to new T { def apply[A] = e } where `A` is a fresh name.
1 parent a974f51 commit 2c77412

4 files changed

Lines changed: 174 additions & 59 deletions

File tree

src/main/scala/Extractors.scala

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,4 +41,18 @@ trait Extractors {
4141
case _ => None
4242
}
4343
}
44+
object TermForall {
45+
def unapply(tree: Tree): Option[Tree] = tree match {
46+
case TypeApply(Ident(TermName("$u2200")), tpe :: Nil) => Some(tpe) // ∀[F]
47+
case _ => None
48+
}
49+
}
50+
object PolyVal {
51+
def unapply(tree: Tree): Option[(Tree, TermName, List[Tree], Tree)] = tree match {
52+
case Apply(TypeApply(Select(TermForall(tpe), method), tParams), body :: Nil) => Some((tpe, method.toTermName, tParams, body))
53+
case Apply(Select(TermForall(tpe), method), body :: Nil) => Some((tpe, method.toTermName, Nil, body))
54+
case Apply(TermForall(tpe), body :: Nil) => Some((tpe, nme.apply, Nil, body))
55+
case _ => None
56+
}
57+
}
4458
}

src/main/scala/KindProjector.scala

Lines changed: 64 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,59 @@ class KindRewriter(plugin: Plugin, val global: Global)
128128
def makeTypeParamContra(name: Name, bounds: TypeBoundsTree = DefaultBounds): TypeDef =
129129
TypeDef(Modifiers(PARAM | CONTRAVARIANT), makeTypeName(name), Nil, bounds)
130130

131-
def polyLambda(tree: Tree): Tree = tree match {
131+
// Given a name, e.g. A or `+A` or `A <: Foo`, build a type
132+
// parameter tree using the given name, bounds, variance, etc.
133+
def makeTypeParamFromName(ident: Ident): TypeDef = {
134+
val decoded = NameTransformer.decode(ident.name.toString)
135+
val src = s"type _X_[$decoded] = Unit"
136+
sp.parse(src) match {
137+
case Some(TypeDef(_, _, List(tpe), _)) => tpe.duplicate
138+
case None => reporter.error(ident.pos, s"Can't parse param: ${ident.name}"); null
139+
}
140+
}
141+
142+
// Like makeTypeParam, but can be used recursively in the case of types
143+
// that are themselves parameterized.
144+
def makeComplexTypeParam(t: Tree): TypeDef = t match {
145+
case id @ Ident(_) =>
146+
makeTypeParamFromName(id)
147+
148+
case TypeDef(m, nm, ps, bs) =>
149+
TypeDef(Modifiers(PARAM), nm, ps.map(makeComplexTypeParam), bs)
150+
151+
case ExistentialTypeTree(AppliedTypeTree(Ident(name), ps), _) =>
152+
val tparams = ps.map(makeComplexTypeParam)
153+
TypeDef(Modifiers(PARAM), makeTypeName(name), tparams, DefaultBounds)
154+
155+
case x =>
156+
reporter.error(x.pos, "Can't parse %s (%s)" format (x, x.getClass.getName))
157+
null.asInstanceOf[TypeDef]
158+
}
159+
160+
def typeArgsToTypeParams(args: List[Tree]): List[TypeDef] = args.map {
161+
case id @ Ident(_) =>
162+
makeTypeParamFromName(id)
163+
164+
case AppliedTypeTree(Ident(Plus), Ident(name) :: Nil) =>
165+
makeTypeParamCo(name)
166+
167+
case AppliedTypeTree(Ident(Minus), Ident(name) :: Nil) =>
168+
makeTypeParamContra(name)
169+
170+
case AppliedTypeTree(Ident(name), ps) =>
171+
val tparams = ps.map(makeComplexTypeParam)
172+
TypeDef(Modifiers(PARAM), makeTypeName(name), tparams, DefaultBounds)
173+
174+
case ExistentialTypeTree(AppliedTypeTree(Ident(name), ps), _) =>
175+
val tparams = ps.map(makeComplexTypeParam)
176+
TypeDef(Modifiers(PARAM), makeTypeName(name), tparams, DefaultBounds)
177+
178+
case x =>
179+
reporter.error(x.pos, "Can't parse %s (%s)" format (x, x.getClass.getName))
180+
null.asInstanceOf[TypeDef]
181+
}
182+
183+
def polyTerm(tree: Tree): Tree = tree match {
132184
case PolyLambda(methodName, (arrowType @ UnappliedType(_ :: targs)) :: Nil, Function1Tree(name, body)) =>
133185
val (f, g) = targs match {
134186
case a :: b :: Nil => (a, b)
@@ -139,41 +191,21 @@ class KindRewriter(plugin: Plugin, val global: Global)
139191
atPos(tree.pos.makeTransparent)(
140192
q"new $arrowType { def $methodName[$TParam]($name: $f[$TParam]): $g[$TParam] = $body }"
141193
)
194+
case PolyVal(targetType, methodName, tArgs, body) =>
195+
atPos(tree.pos.makeTransparent)(tArgs match {
196+
case Nil =>
197+
val tParam = freshTypeName("A")(currentFreshNameCreator)
198+
q"new $targetType { def $methodName[$tParam] = $body }"
199+
case _ =>
200+
val tParams = typeArgsToTypeParams(tArgs)
201+
q"new $targetType { def $methodName[..$tParams] = $body }"
202+
})
142203
case _ => tree
143204
}
144205

145206
// The transform method -- this is where the magic happens.
146207
override def transform(tree: Tree): Tree = {
147208

148-
// Given a name, e.g. A or `+A` or `A <: Foo`, build a type
149-
// parameter tree using the given name, bounds, variance, etc.
150-
def makeTypeParamFromName(name: Name): TypeDef = {
151-
val decoded = NameTransformer.decode(name.toString)
152-
val src = s"type _X_[$decoded] = Unit"
153-
sp.parse(src) match {
154-
case Some(TypeDef(_, _, List(tpe), _)) => tpe
155-
case None => reporter.error(tree.pos, s"Can't parse param: $name"); null
156-
}
157-
}
158-
159-
// Like makeTypeParam, but can be used recursively in the case of types
160-
// that are themselves parameterized.
161-
def makeComplexTypeParam(t: Tree): TypeDef = t match {
162-
case Ident(name) =>
163-
makeTypeParamFromName(name)
164-
165-
case TypeDef(m, nm, ps, bs) =>
166-
TypeDef(Modifiers(PARAM), nm, ps.map(makeComplexTypeParam), bs)
167-
168-
case ExistentialTypeTree(AppliedTypeTree(Ident(name), ps), _) =>
169-
val tparams = ps.map(makeComplexTypeParam)
170-
TypeDef(Modifiers(PARAM), makeTypeName(name), tparams, DefaultBounds)
171-
172-
case x =>
173-
reporter.error(x.pos, "Can't parse %s (%s)" format (x, x.getClass.getName))
174-
null.asInstanceOf[TypeDef]
175-
}
176-
177209
// Given the list a::as, this method finds the last argument in the list
178210
// (the "subtree") and returns that separately from the other arguments.
179211
// The stack is just used to enable tail recursion, and a and as are
@@ -204,28 +236,7 @@ class KindRewriter(plugin: Plugin, val global: Global)
204236
// Lambda[(A, B) => Function2[A, Int, B]] case.
205237
def handleLambda(a: Tree, as: List[Tree]): Tree = {
206238
val (args, subtree) = parseLambda(a, as, Nil)
207-
val innerTypes = args.map {
208-
case Ident(name) =>
209-
makeTypeParamFromName(name)
210-
211-
case AppliedTypeTree(Ident(Plus), Ident(name) :: Nil) =>
212-
makeTypeParamCo(name)
213-
214-
case AppliedTypeTree(Ident(Minus), Ident(name) :: Nil) =>
215-
makeTypeParamContra(name)
216-
217-
case AppliedTypeTree(Ident(name), ps) =>
218-
val tparams = ps.map(makeComplexTypeParam)
219-
TypeDef(Modifiers(PARAM), makeTypeName(name), tparams, DefaultBounds)
220-
221-
case ExistentialTypeTree(AppliedTypeTree(Ident(name), ps), _) =>
222-
val tparams = ps.map(makeComplexTypeParam)
223-
TypeDef(Modifiers(PARAM), makeTypeName(name), tparams, DefaultBounds)
224-
225-
case x =>
226-
reporter.error(x.pos, "Can't parse %s (%s)" format (x, x.getClass.getName))
227-
null.asInstanceOf[TypeDef]
228-
}
239+
val innerTypes = typeArgsToTypeParams(args)
229240
makeTypeProjection(innerTypes, subtree)
230241
}
231242

@@ -321,7 +332,7 @@ class KindRewriter(plugin: Plugin, val global: Global)
321332
// given a tree, see if it could possibly be a type lambda
322333
// (either placeholder syntax or lambda syntax). if so, handle
323334
// it, and if not, transform it in the normal way.
324-
val result = polyLambda(tree match {
335+
val result = polyTerm(tree match {
325336

326337
// Lambda[A => Either[A, Int]] case.
327338
case AppliedTypeTree(Ident(TypeLambda1), AppliedTypeTree(target, a :: as) :: Nil) =>

src/test/scala/polylambda.scala

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,18 @@ package d_m
22

33
import org.junit.Test
44

5-
trait ~>[-F[_], +G[_]] {
6-
def apply[A](x: F[A]): G[A]
7-
}
8-
trait ~>>[-F[_], +G[_]] {
9-
def dingo[B](x: F[B]): G[B]
10-
}
115
final case class Const[A, B](getConst: A)
126

137
class PolyLambdas {
8+
9+
trait ~>[-F[_], +G[_]] {
10+
def apply[A](x: F[A]): G[A]
11+
}
12+
13+
trait ~>>[-F[_], +G[_]] {
14+
def dingo[B](x: F[B]): G[B]
15+
}
16+
1417
type ToSelf[F[_]] = F ~> F
1518

1619
val kf1 = Lambda[Option ~> Vector](_.toVector)

src/test/scala/polyval.scala

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
package d_m
2+
3+
import org.junit.Test
4+
5+
class PolyVals {
6+
7+
trait Forall[F[_]] {
8+
def apply[A]: F[A]
9+
}
10+
11+
trait Forall2[F[_, _]] {
12+
def apply[A, B]: F[A, B]
13+
}
14+
15+
trait ForallK[F[_[_]]] {
16+
def apply[G[_]]: F[G]
17+
}
18+
19+
20+
trait Semigroup[A] {
21+
def combine(x: A, y: A): A
22+
}
23+
24+
def listSemigroup[A]: Semigroup[List[A]] = new Semigroup[List[A]] {
25+
def combine(x: List[A], y: List[A]): List[A] = x ++ y
26+
}
27+
28+
trait Functor[F[_]]
29+
trait Monad[F[_]] {
30+
def functor: Functor[F]
31+
}
32+
33+
// universally quantified semigroup
34+
type SemigroupK[F[_]] = Forall[λ[α => Semigroup[F[α]]]]
35+
36+
// natural transformations
37+
type ~>[F[_] , G[_] ] = Forall [λ[α => F[α] => G[α]]]
38+
type ≈>[F[_[_]], G[_[_]]] = ForallK[λ[α[_] => F[α] => G[α]]]
39+
40+
// Const functor and constructors
41+
type ConstA[A] = Forall[Const[A, ?]]
42+
type ConstMaker = Forall[λ[α => α => ConstA[α]]]
43+
type ConstMaker2 = Forall2[λ[(α, β) => α => Const[α, β]]]
44+
45+
// existentials via universals
46+
type Consumer[F[_], R] = Forall[λ[A => F[A] => R]]
47+
type Exists[F[_]] = Forall[λ[R => Consumer[F, R] => R]]
48+
def existential[F[_], A](fa: F[A]): Exists[F] = [Exists[F]](_[A](fa))
49+
50+
@Test
51+
def testSemigroupK(): Unit = {
52+
val listSemigroupK = [SemigroupK[List]](listSemigroup)
53+
assert(listSemigroupK[Int].combine(List(1, 2), List(3, 4)) == List(1, 2, 3, 4))
54+
}
55+
56+
@Test
57+
def testNaturalTransformations(): Unit = {
58+
val headOption = [List ~> Option].apply[A](_.headOption)
59+
val monadToFunctor1 = [Monad ≈> Functor].apply[F[_]](_.functor)
60+
val monadToFunctor2 = [Monad ≈> Functor].apply[F[_]]((m: Monad[F]) => m.functor)
61+
62+
val listFunctor = new Functor[List] {}
63+
val listMonad = new Monad[List] { def functor = listFunctor }
64+
65+
assert(headOption[Int](List(1, 2)) == Some(1))
66+
assert(monadToFunctor1[List](listMonad) == listFunctor)
67+
assert(monadToFunctor2[List](listMonad) == listFunctor)
68+
}
69+
70+
@Test
71+
def testConst(): Unit = {
72+
val const42 = [ConstA[Int]].apply[B](new Const[Int, B](42))
73+
val constMaker = [ConstMaker] .apply[A] (a => [ConstA[A]].apply[B](new Const[A, B](a)))
74+
val constMaker2 = [ConstMaker2].apply[A, B](a => new Const[A, B](a) )
75+
76+
assert(const42[String].getConst == 42)
77+
assert(constMaker[Int](42)[String].getConst == 42)
78+
assert(constMaker2[Int, String](42).getConst == 42)
79+
}
80+
81+
@Test
82+
def testExistential(): Unit = {
83+
val list = existential(List("one", "two", "three"))
84+
val len = [Consumer[List, Int]](_.length)
85+
assert(list[Int](len) == 3)
86+
}
87+
}

0 commit comments

Comments
 (0)