@@ -15,12 +15,15 @@ import Trees.*
1515import Types .*
1616import Symbols .*
1717import Names .*
18+ import StdNames .str
1819import NameOps .*
1920import inlines .Inlines
2021import transform .ValueClasses
21- import dotty .tools .io .{File , FileExtension }
22+ import dotty .tools .io .{File , FileExtension , JarArchive }
23+ import util .{Property , SourceFile }
2224import java .io .PrintWriter
2325
26+ import ExtractAPI .NonLocalClassSymbolsInCurrentUnits
2427
2528import scala .collection .mutable
2629import scala .util .hashing .MurmurHash3
@@ -64,13 +67,62 @@ class ExtractAPI extends Phase {
6467 // definitions, and `PostTyper` does not change definitions).
6568 override def runsAfter : Set [String ] = Set (transform.PostTyper .name)
6669
70+ override def runOn (units : List [CompilationUnit ])(using Context ): List [CompilationUnit ] =
71+ val nonLocalClassSymbols = new mutable.HashSet [Symbol ]
72+ val ctx0 = ctx.withProperty(NonLocalClassSymbolsInCurrentUnits , Some (nonLocalClassSymbols))
73+ val units0 = super .runOn(units)(using ctx0)
74+ ctx.withIncCallback(recordNonLocalClasses(nonLocalClassSymbols, _))
75+ units0
76+ end runOn
77+
78+ private def recordNonLocalClasses (nonLocalClassSymbols : mutable.HashSet [Symbol ], cb : interfaces.IncrementalCallback )(using Context ): Unit =
79+ for cls <- nonLocalClassSymbols do
80+ val sourceFile = cls.source
81+ if sourceFile.exists && cls.isDefinedInCurrentRun then
82+ recordNonLocalClass(cls, sourceFile, cb)
83+ cb.apiPhaseCompleted()
84+ cb.dependencyPhaseCompleted()
85+
86+ private def recordNonLocalClass (cls : Symbol , sourceFile : SourceFile , cb : interfaces.IncrementalCallback )(using Context ): Unit =
87+ def registerProductNames (fullClassName : String , binaryClassName : String ) =
88+ val pathToClassFile = s " ${binaryClassName.replace('.' , java.io.File .separatorChar)}.class "
89+
90+ val classFile = {
91+ ctx.settings.outputDir.value match {
92+ case jar : JarArchive =>
93+ // important detail here, even on Windows, Zinc expects the separator within the jar
94+ // to be the system default, (even if in the actual jar file the entry always uses '/').
95+ // see https://github.com/sbt/zinc/blob/dcddc1f9cfe542d738582c43f4840e17c053ce81/internal/compiler-bridge/src/main/scala/xsbt/JarUtils.scala#L47
96+ new java.io.File (s " $jar! $pathToClassFile" )
97+ case outputDir =>
98+ new java.io.File (outputDir.file, pathToClassFile)
99+ }
100+ }
101+
102+ cb.generatedNonLocalClass(sourceFile, classFile.toPath(), binaryClassName, fullClassName)
103+ end registerProductNames
104+
105+ val fullClassName = atPhase(sbtExtractDependenciesPhase) {
106+ ExtractDependencies .classNameAsString(cls)
107+ }
108+ val binaryClassName = cls.binaryClassName
109+ registerProductNames(fullClassName, binaryClassName)
110+
111+ // Register the names of top-level module symbols that emit two class files
112+ val isTopLevelUniqueModule =
113+ cls.owner.is(PackageClass ) && cls.is(ModuleClass ) && cls.companionClass == NoSymbol
114+ if isTopLevelUniqueModule then
115+ registerProductNames(fullClassName, binaryClassName.stripSuffix(str.MODULE_SUFFIX ))
116+ end recordNonLocalClass
117+
67118 override def run (using Context ): Unit = {
68119 val unit = ctx.compilationUnit
69120 val sourceFile = unit.source
70121 ctx.withIncCallback: cb =>
71122 cb.startSource(sourceFile)
72123
73- val apiTraverser = new ExtractAPICollector
124+ val nonLocalClassSymbols = ctx.property(NonLocalClassSymbolsInCurrentUnits ).get
125+ val apiTraverser = ExtractAPICollector (nonLocalClassSymbols)
74126 val classes = apiTraverser.apiSource(unit.tpdTree)
75127 val mainClasses = apiTraverser.mainClasses
76128
@@ -94,6 +146,8 @@ object ExtractAPI:
94146 val name : String = " sbt-api"
95147 val description : String = " sends a representation of the API of classes to sbt"
96148
149+ private val NonLocalClassSymbolsInCurrentUnits : Property .Key [mutable.HashSet [Symbol ]] = Property .Key ()
150+
97151/** Extracts full (including private members) API representation out of Symbols and Types.
98152 *
99153 * The exact representation used for each type is not important: the only thing
@@ -136,7 +190,7 @@ object ExtractAPI:
136190 * without going through an intermediate representation, see
137191 * http://www.scala-sbt.org/0.13/docs/Understanding-Recompilation.html#Hashing+an+API+representation
138192 */
139- private class ExtractAPICollector (using Context ) extends ThunkHolder {
193+ private class ExtractAPICollector (nonLocalClassSymbols : mutable. HashSet [ Symbol ])( using Context ) extends ThunkHolder {
140194 import tpd .*
141195 import xsbti .api
142196
@@ -254,6 +308,8 @@ private class ExtractAPICollector(using Context) extends ThunkHolder {
254308 childrenOfSealedClass, topLevel, tparams)
255309
256310 allNonLocalClassesInSrc += cl
311+ if ! sym.isLocal then
312+ nonLocalClassSymbols += sym
257313
258314 if (sym.isStatic && ! sym.is(Trait ) && ctx.platform.hasMainMethod(sym)) {
259315 // If sym is an object, all main methods count, otherwise only @static ones count.
0 commit comments