Skip to content

Commit 6401ae1

Browse files
committed
[no-master] Fix scala-js/scala-js#3408: Survive concurrent calls to send and receive.
1 parent dcd5ed3 commit 6401ae1

File tree

2 files changed

+94
-19
lines changed

2 files changed

+94
-19
lines changed

js-envs-test-suite/src/test/scala/org/scalajs/jsenv/test/NodeJSTest.scala

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,4 +76,35 @@ class NodeJSTest extends TimeoutComTests {
7676
com.await(DefaultTimeout)
7777
}
7878

79+
@Test
80+
def testConcurrentSendReceive_issue3408: Unit = {
81+
for (_ <- 0 until 50) {
82+
val com = comRunner("""
83+
scalajsCom.init(function(msg) {
84+
scalajsCom.send("pong: " + msg);
85+
});
86+
""")
87+
88+
start(com)
89+
90+
// Try very hard to send and receive at the same time
91+
val lock = new AnyRef
92+
val threadSend = new Thread {
93+
override def run(): Unit = {
94+
lock.synchronized(lock.wait())
95+
com.send("ping")
96+
}
97+
}
98+
threadSend.start()
99+
100+
Thread.sleep(200L)
101+
lock.synchronized(lock.notifyAll())
102+
assertEquals(com.receive(), "pong: ping")
103+
104+
threadSend.join()
105+
com.close()
106+
com.await(DefaultTimeout)
107+
}
108+
}
109+
79110
}

js-envs/src/main/scala/org/scalajs/jsenv/nodejs/AbstractNodeJSEnv.scala

Lines changed: 63 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99

1010
package org.scalajs.jsenv.nodejs
1111

12+
import scala.annotation.tailrec
13+
1214
import java.io.{Console => _, _}
1315
import java.net._
1416

@@ -153,8 +155,17 @@ abstract class AbstractNodeJSEnv(
153155

154156
protected trait NodeComJSRunner extends ComJSRunner with JSInitFiles {
155157

158+
/* Manipulation of the socket must be protected by synchronized, except
159+
* calls to `close()`.
160+
*/
156161
private[this] val serverSocket =
157162
new ServerSocket(0, 0, InetAddress.getByName(null)) // Loopback address
163+
164+
/* Those 3 fields are assigned *once* under synchronization in
165+
* `awaitConnection()`.
166+
* Read access must be protected by synchronized, or be done after a
167+
* successful call to `awaitConnection()`.
168+
*/
158169
private var comSocket: Socket = _
159170
private var jvm2js: DataOutputStream = _
160171
private var js2jvm: DataInputStream = _
@@ -280,34 +291,67 @@ abstract class AbstractNodeJSEnv(
280291
}
281292

282293
def close(): Unit = {
294+
/* Close the socket first. This will cause any existing and upcoming
295+
* calls to `awaitConnection()` to be canceled and throw a
296+
* `SocketException` (unless it has already successfully completed the
297+
* `accept()` call).
298+
*/
283299
serverSocket.close()
284-
if (jvm2js != null)
285-
jvm2js.close()
286-
if (js2jvm != null)
287-
js2jvm.close()
288-
if (comSocket != null)
289-
comSocket.close()
300+
301+
/* Now wait for a possibly still-successful `awaitConnection()` to
302+
* complete before closing the sockets.
303+
*/
304+
synchronized {
305+
if (comSocket != null) {
306+
jvm2js.close()
307+
js2jvm.close()
308+
comSocket.close()
309+
}
310+
}
290311
}
291312

292313
/** Waits until the JS VM has established a connection or terminates
293314
*
294315
* @return true if the connection was established
295316
*/
296-
private def awaitConnection(): Boolean = {
297-
serverSocket.setSoTimeout(acceptTimeout)
298-
while (comSocket == null && isRunning) {
299-
try {
300-
comSocket = serverSocket.accept()
301-
jvm2js = new DataOutputStream(
302-
new BufferedOutputStream(comSocket.getOutputStream()))
303-
js2jvm = new DataInputStream(
304-
new BufferedInputStream(comSocket.getInputStream()))
305-
} catch {
306-
case to: SocketTimeoutException =>
317+
private def awaitConnection(): Boolean = synchronized {
318+
if (comSocket != null) {
319+
true
320+
} else {
321+
@tailrec
322+
def acceptLoop(): Option[Socket] = {
323+
if (!isRunning) {
324+
None
325+
} else {
326+
try {
327+
Some(serverSocket.accept())
328+
} catch {
329+
case to: SocketTimeoutException => acceptLoop()
330+
}
331+
}
307332
}
308-
}
309333

310-
comSocket != null
334+
serverSocket.setSoTimeout(acceptTimeout)
335+
val optComSocket = acceptLoop()
336+
337+
optComSocket.fold {
338+
false
339+
} { comSocket0 =>
340+
val jvm2js0 = new DataOutputStream(
341+
new BufferedOutputStream(comSocket0.getOutputStream()))
342+
val js2jvm0 = new DataInputStream(
343+
new BufferedInputStream(comSocket0.getInputStream()))
344+
345+
/* Assign those three fields together, without the possibility of
346+
* an exception happening in the middle (see #3408).
347+
*/
348+
comSocket = comSocket0
349+
jvm2js = jvm2js0
350+
js2jvm = js2jvm0
351+
352+
true
353+
}
354+
}
311355
}
312356

313357
override protected def finalize(): Unit = close()

0 commit comments

Comments
 (0)