Skip to content

Commit 2f3e266

Browse files
authored
fix async client for transformers chat (#41255)
* fix-client * fix
1 parent 313504b commit 2f3e266

File tree

1 file changed

+38
-40
lines changed

1 file changed

+38
-40
lines changed

src/transformers/commands/chat.py

Lines changed: 38 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -687,7 +687,6 @@ async def _inner_run(self):
687687

688688
model = self.args.model_name_or_path + "@" + self.args.model_revision
689689
host = "http://localhost" if self.args.host == "localhost" else self.args.host
690-
client = AsyncInferenceClient(f"{host}:{self.args.port}")
691690

692691
args = self.args
693692
if args.examples_path is None:
@@ -710,48 +709,47 @@ async def _inner_run(self):
710709

711710
# Starts the session with a minimal help message at the top, so that a user doesn't get stuck
712711
interface.print_help(minimal=True)
713-
while True:
714-
try:
715-
user_input = interface.input()
716-
717-
# User commands
718-
if user_input.startswith("!"):
719-
# `!exit` is special, it breaks the loop
720-
if user_input == "!exit":
721-
break
722-
else:
723-
chat, valid_command, generation_config, model_kwargs = self.handle_non_exit_user_commands(
724-
user_input=user_input,
725-
args=args,
726-
interface=interface,
727-
examples=examples,
728-
generation_config=generation_config,
729-
model_kwargs=model_kwargs,
730-
chat=chat,
731-
)
732-
# `!example` sends a user message to the model
733-
if not valid_command or not user_input.startswith("!example"):
734-
continue
735-
else:
736-
chat.append({"role": "user", "content": user_input})
737-
738-
stream = client.chat_completion(
739-
chat,
740-
stream=True,
741-
extra_body={
742-
"generation_config": generation_config.to_json_string(),
743-
"model": model,
744-
},
745-
)
746712

747-
model_output = await interface.stream_output(stream)
713+
async with AsyncInferenceClient(f"{host}:{self.args.port}") as client:
714+
while True:
715+
try:
716+
user_input = interface.input()
717+
718+
# User commands
719+
if user_input.startswith("!"):
720+
# `!exit` is special, it breaks the loop
721+
if user_input == "!exit":
722+
break
723+
else:
724+
chat, valid_command, generation_config, model_kwargs = self.handle_non_exit_user_commands(
725+
user_input=user_input,
726+
args=args,
727+
interface=interface,
728+
examples=examples,
729+
generation_config=generation_config,
730+
model_kwargs=model_kwargs,
731+
chat=chat,
732+
)
733+
# `!example` sends a user message to the model
734+
if not valid_command or not user_input.startswith("!example"):
735+
continue
736+
else:
737+
chat.append({"role": "user", "content": user_input})
738+
739+
stream = client.chat_completion(
740+
chat,
741+
stream=True,
742+
extra_body={
743+
"generation_config": generation_config.to_json_string(),
744+
"model": model,
745+
},
746+
)
748747

749-
chat.append({"role": "assistant", "content": model_output})
748+
model_output = await interface.stream_output(stream)
750749

751-
except KeyboardInterrupt:
752-
break
753-
finally:
754-
await client.close()
750+
chat.append({"role": "assistant", "content": model_output})
751+
except KeyboardInterrupt:
752+
break
755753

756754

757755
if __name__ == "__main__":

0 commit comments

Comments
 (0)