Skip to content

Commit 920cb04

Browse files
committed
Nicer handling of the end-of-stream ("[DONE]"_
1 parent f298ddc commit 920cb04

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

openai-client-stream/src/main/scala/io/cequence/openaiscala/service/OpenAIServiceStreamedImpl.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import io.cequence.openaiscala.domain.settings._
99
import io.cequence.openaiscala.domain.response._
1010
import io.cequence.openaiscala.service.ws.{Timeouts, WSStreamRequestHelper}
1111
import io.cequence.openaiscala.OpenAIScalaClientException
12-
import play.api.libs.json.JsValue
12+
import play.api.libs.json.{JsValue, Json}
1313

1414
import scala.concurrent.ExecutionContext
1515

openai-client-stream/src/main/scala/io/cequence/openaiscala/service/ws/WSStreamRequestHelper.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import akka.stream.scaladsl.{Flow, Framing, Source}
99
import akka.util.ByteString
1010
import com.fasterxml.jackson.core.JsonParseException
1111
import io.cequence.openaiscala.{OpenAIScalaClientException, OpenAIScalaClientTimeoutException, OpenAIScalaClientUnknownHostException}
12-
import play.api.libs.json.{JsNull, JsObject, JsValue, Json}
12+
import play.api.libs.json.{JsNull, JsObject, JsString, JsValue, Json}
1313
import play.api.libs.ws.JsonBodyWritables._
1414

1515
import java.net.UnknownHostException
@@ -32,7 +32,7 @@ trait WSStreamRequestHelper {
3232
private implicit val jsonMarshaller: Unmarshaller[ByteString, JsValue] =
3333
Unmarshaller.strict[ByteString, JsValue] { byteString =>
3434
val data = byteString.utf8String.stripPrefix(itemPrefix)
35-
if (data.equals(endOfStreamToken)) JsNull else Json.parse(data)
35+
if (data.equals(endOfStreamToken)) JsString(endOfStreamToken) else Json.parse(data)
3636
}
3737

3838
protected def execJsonStreamAux(
@@ -56,8 +56,8 @@ trait WSStreamRequestHelper {
5656
}
5757
)
5858

59-
// filter the end of stream marked with JsNull
60-
source.filter(_ != JsNull)
59+
// take until you encounter the end of stream marked with DONE
60+
source.takeWhile(_ != JsString(endOfStreamToken))
6161
}
6262

6363
protected def execStreamRequestAux[T](

0 commit comments

Comments
 (0)