# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import asyncio
import grpc
import requests
from typing import Dict, Any, Callable, List, Union
from mii.batching.data_classes import Response
from mii.config import MIIConfig
from mii.constants import GRPC_MAX_MSG_SIZE
from mii.grpc_related.proto import modelresponse_pb2, modelresponse_pb2_grpc
from mii.grpc_related.task_methods import TASK_METHODS_DICT
def create_channel(host, port):
    """
    Open an insecure asyncio gRPC channel to ``host:port``.

    Both the maximum send and maximum receive message lengths are raised to
    ``GRPC_MAX_MSG_SIZE`` so large prompts/responses are not rejected by the
    default gRPC limits.
    """
    target = f"{host}:{port}"
    size_limits = [
        ("grpc.max_send_message_length", GRPC_MAX_MSG_SIZE),
        ("grpc.max_receive_message_length", GRPC_MAX_MSG_SIZE),
    ]
    return grpc.aio.insecure_channel(target, options=size_limits)
class MIIClient:
    """
    Client for sending generation requests to a persistent deployment created
    with :func:`mii.serve`. Use :func:`mii.client` to create an instance of this
    class.

    :param mii_config: MII config for the persistent deployment to connect with.
    :param host: hostname where the persistent deployment is running.
    """
    def __init__(self, mii_config: MIIConfig, host: str = "localhost") -> None:
        self.mii_config = mii_config
        self.task = mii_config.model_conf.task
        self.port = mii_config.port_number
        # Kept so terminate_server() contacts the same host as the gRPC
        # channel instead of always assuming "localhost".
        self.host = host
        self.asyncio_loop = self._get_or_create_event_loop()
        channel = create_channel(host, self.port)
        # This stub allows the client to send/receive messages with the load
        # balancer process.
        self.stub = modelresponse_pb2_grpc.ModelResponseStub(channel)

    @staticmethod
    def _get_or_create_event_loop() -> asyncio.AbstractEventLoop:
        """
        Return the calling thread's event loop, creating and installing a new
        one when none is available.

        ``asyncio.get_event_loop()`` can raise ``RuntimeError`` when no current
        loop is set (e.g. in a non-main thread); without this fallback the
        client could not be constructed there.
        """
        try:
            return asyncio.get_event_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
            return loop

    def __call__(self, *args, **kwargs) -> Union[List[Response], None]:
        """
        All args and kwargs get passed directly to
        :meth:`~mii.backend.client.MIIClient.generate`.

        :return: A list of :class:`Response` objects containing the generated
            text for all prompts (``None`` when ``streaming_fn`` is given).
        """
        return self.generate(*args, **kwargs)

    async def _request_async_response(self, prompts, **query_kwargs):
        """
        Send ``prompts`` to the task's unary gRPC method and return the
        unpacked response.
        """
        task_methods = TASK_METHODS_DICT[self.task]
        proto_request = task_methods.pack_request_to_proto(prompts, **query_kwargs)
        proto_response = await getattr(self.stub, task_methods.method)(proto_request)
        return task_methods.unpack_response_from_proto(proto_response)

    async def _request_async_response_stream(self, prompts, **query_kwargs):
        """
        Send ``prompts`` to the task's streaming gRPC method and asynchronously
        yield each unpacked response as it arrives.

        :raises RuntimeError: if the task defines no ``method_stream_out``.
        """
        task_methods = TASK_METHODS_DICT[self.task]
        # Explicit check instead of ``assert`` so it still fires under
        # ``python -O``.
        if not hasattr(task_methods, "method_stream_out"):
            raise RuntimeError(f"{self.task} does not support streaming response")
        proto_request = task_methods.pack_request_to_proto(prompts, **query_kwargs)
        stream_method = getattr(self.stub, task_methods.method_stream_out)
        async for response in stream_method(proto_request):
            yield task_methods.unpack_response_from_proto(response)

    def generate(self,
                 prompts: Union[str, List[str]],
                 streaming_fn: Union[Callable, None] = None,
                 **generate_kwargs: Any) -> Union[List[Response], None]:
        """
        Generates text for the given prompts.

        :param prompts: The string or list of strings used as prompts for generation.
        :param streaming_fn: Streaming support is currently a WIP. When given,
            it is called with each unpacked streamed response, only a single
            prompt is accepted, and ``generate`` returns ``None``.
        :param \\*\\*generate_kwargs: Generation keyword arguments, forwarded to
            the task's request packer.
        :return: A list of :class:`Response` objects containing the generated
            text for all prompts, or ``None`` when ``streaming_fn`` is given.
        :raises RuntimeError: if streaming is requested for more than one prompt.
        """
        if isinstance(prompts, str):
            prompts = [prompts]
        if streaming_fn is not None:
            if len(prompts) > 1:
                raise RuntimeError(
                    "MII client streaming only supports a single prompt input.")
            generate_kwargs["stream"] = True
            return self._generate_stream(streaming_fn, prompts, **generate_kwargs)
        return self.asyncio_loop.run_until_complete(
            self._request_async_response(prompts, **generate_kwargs))

    def _generate_stream(self,
                         callback: Callable,
                         prompts: List[str],
                         **query_kwargs: Any) -> None:
        """
        Drive the streaming request to completion on this client's event loop,
        invoking ``callback`` once per streamed response.
        """
        async def put_result():
            async for response in self._request_async_response_stream(
                    prompts, **query_kwargs):
                callback(response)

        self.asyncio_loop.run_until_complete(put_result())

    async def terminate_async(self) -> None:
        """Send the gRPC ``Terminate`` request (with an empty message) to the server."""
        await self.stub.Terminate(
            modelresponse_pb2.google_dot_protobuf_dot_empty__pb2.Empty())

    def terminate_server(self) -> None:
        """
        Terminates the persistent deployment server. This can be called from any
        client.

        When the RESTful API is enabled in the config, a GET request is also sent
        to its ``/terminate`` endpoint on the host this client was created for.
        """
        self.asyncio_loop.run_until_complete(self.terminate_async())
        if self.mii_config.enable_restful_api:
            requests.get(
                f"http://{self.host}:{self.mii_config.restful_api_port}/terminate")