@@ -163,40 +163,46 @@ abstract class WeakHashSet[A <: AnyRef](initialCapacity: Int = 8, loadFactor: Do
163163 elem
164164 }
165165
166- def put (elem : A ): A = {
167- Stats .record(statsItem(" put" ))
168- removeStaleEntries()
169- val h = hash(elem)
170- val bucket = index(h)
171- val oldHead = table(bucket)
166+ // TODO: remove the `case null` when we can enable explicit nulls in regular compiling,
167+ // since the type `A <: AnyRef` of `elem` can ensure the value is not null.
168+ def put (elem : A ): A = (elem : A | Null ) match {
169+ case null => throw new NullPointerException (" WeakHashSet cannot hold nulls" )
170+ case _ =>
171+ Stats .record(statsItem(" put" ))
172+ removeStaleEntries()
173+ val h = hash(elem)
174+ val bucket = index(h)
175+ val oldHead = table(bucket)
172176
173- @ tailrec
174- def linkedListLoop (entry : Entry [A ] | Null ): A = entry match {
175- case null => addEntryAt(bucket, elem, h, oldHead)
176- case _ =>
177- val entryElem = entry.get
178- if entryElem != null && isEqual(elem, entryElem) then entryElem.uncheckedNN
179- else linkedListLoop(entry.tail)
180- }
177+ @ tailrec
178+ def linkedListLoop (entry : Entry [A ] | Null ): A = entry match {
179+ case null => addEntryAt(bucket, elem, h, oldHead)
180+ case _ =>
181+ val entryElem = entry.get
182+ if entryElem != null && isEqual(elem, entryElem) then entryElem.uncheckedNN
183+ else linkedListLoop(entry.tail)
184+ }
181185
182- linkedListLoop(oldHead)
186+ linkedListLoop(oldHead)
183187 }
184188
185189 def += (elem : A ): Unit = put(elem)
186190
187- def -= (elem : A ): Unit = {
188- Stats .record(statsItem(" -=" ))
189- removeStaleEntries()
190- val bucket = index(hash(elem))
191+ def -= (elem : A ): Unit = (elem : A | Null ) match {
192+ case null =>
193+ case _ =>
194+ Stats .record(statsItem(" -=" ))
195+ removeStaleEntries()
196+ val bucket = index(hash(elem))
191197
192- @ tailrec
193- def linkedListLoop (prevEntry : Entry [A ] | Null , entry : Entry [A ] | Null ): Unit =
194- if entry != null then
195- val entryElem = entry.get
196- if entryElem != null && isEqual(elem, entryElem) then remove(bucket, prevEntry, entry)
197- else linkedListLoop(entry, entry.tail)
198+ @ tailrec
199+ def linkedListLoop (prevEntry : Entry [A ] | Null , entry : Entry [A ] | Null ): Unit =
200+ if entry != null then
201+ val entryElem = entry.get
202+ if entryElem != null && isEqual(elem, entryElem) then remove(bucket, prevEntry, entry)
203+ else linkedListLoop(entry, entry.tail)
198204
199- linkedListLoop(null , table(bucket))
205+ linkedListLoop(null , table(bucket))
200206 }
201207
202208 def clear (): Unit = {
@@ -255,14 +261,13 @@ abstract class WeakHashSet[A <: AnyRef](initialCapacity: Int = 8, loadFactor: Do
255261 }
256262
257263 def next (): A =
258- if ( lookaheadelement == null )
264+ if lookaheadelement == null then
259265 throw new IndexOutOfBoundsException (" next on an empty iterator" )
260- else {
266+ else
261267 val result = lookaheadelement.nn
262268 lookaheadelement = null
263269 entry = entry.nn.tail
264270 result
265- }
266271 }
267272 }
268273
0 commit comments