@@ -61,8 +61,7 @@ abstract class WeakHashSet[A <: AnyRef](initialCapacity: Int = 8, loadFactor: Do
6161 private def computeThreshold : Int = (table.size * loadFactor).ceil.toInt
6262
6363 protected def hash (key : A ): Int
64- protected def isEqual (x : A | Null , y : A | Null ): Boolean =
65- if x == null then y == null else x.equals(y)
64+ protected def isEqual (x : A , y : A ): Boolean = x.equals(y)
6665
6766 /** Turn hashcode `x` into a table index */
6867 protected def index (x : Int ): Int = x & (table.length - 1 )
@@ -135,24 +134,25 @@ abstract class WeakHashSet[A <: AnyRef](initialCapacity: Int = 8, loadFactor: Do
135134 tableLoop(0 )
136135 }
137136
138- def lookup (elem : A ): A | Null = {
139- // case null => throw new NullPointerException("WeakHashSet cannot hold nulls")
140- // case _ =>
137+ // TODO: remove the `case null` when we can enable explicit nulls in regular compiling,
138+ // since the type `A <: AnyRef` of `elem` can ensure the value is not null.
139+ def lookup (elem : A ): A | Null = (elem : A | Null ) match {
140+ case null => throw new NullPointerException (" WeakHashSet cannot hold nulls" )
141+ case _ =>
142+ Stats .record(statsItem(" lookup" ))
143+ removeStaleEntries()
144+ val bucket = index(hash(elem))
141145
142- Stats .record(statsItem(" lookup" ))
143- removeStaleEntries()
144- val bucket = index(hash(elem))
145-
146- @ tailrec
147- def linkedListLoop (entry : Entry [A ] | Null ): A | Null = entry match {
148- case null => null
149- case _ =>
150- val entryElem = entry.get
151- if (isEqual(elem, entryElem)) entryElem
152- else linkedListLoop(entry.tail)
153- }
146+ @ tailrec
147+ def linkedListLoop (entry : Entry [A ] | Null ): A | Null = entry match {
148+ case null => null
149+ case _ =>
150+ val entryElem = entry.get
151+ if entryElem != null && isEqual(elem, entryElem) then entryElem
152+ else linkedListLoop(entry.tail)
153+ }
154154
155- linkedListLoop(table(bucket))
155+ linkedListLoop(table(bucket))
156156 }
157157
158158 protected def addEntryAt (bucket : Int , elem : A , elemHash : Int , oldHead : Entry [A ] | Null ): A = {
@@ -175,7 +175,7 @@ abstract class WeakHashSet[A <: AnyRef](initialCapacity: Int = 8, loadFactor: Do
175175 case null => addEntryAt(bucket, elem, h, oldHead)
176176 case _ =>
177177 val entryElem = entry.get
178- if ( isEqual(elem, entryElem)) entryElem.uncheckedNN
178+ if entryElem != null && isEqual(elem, entryElem) then entryElem.uncheckedNN
179179 else linkedListLoop(entry.tail)
180180 }
181181
@@ -192,7 +192,8 @@ abstract class WeakHashSet[A <: AnyRef](initialCapacity: Int = 8, loadFactor: Do
192192 @ tailrec
193193 def linkedListLoop (prevEntry : Entry [A ] | Null , entry : Entry [A ] | Null ): Unit =
194194 if entry != null then
195- if isEqual(elem, entry.get) then remove(bucket, prevEntry, entry)
195+ val entryElem = entry.get
196+ if entryElem != null && isEqual(elem, entryElem) then remove(bucket, prevEntry, entry)
196197 else linkedListLoop(entry, entry.tail)
197198
198199 linkedListLoop(null , table(bucket))
0 commit comments