# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0
# DeepSpeed Team
import copy
import gc
import os
import queue
import random
import threading
import time
from dataclasses import dataclass
from collections import deque, defaultdict
from functools import cached_property
from typing import Dict, Tuple, List, Any, Union, DefaultDict
import torch
import ujson
import zmq
from deepspeed.accelerator import get_accelerator
from deepspeed.utils.timer import SynchronizedWallClockTimer
from mii.batching.constants import TOP_K_NAME, TOP_P_NAME, TEMP_NAME, SAMPLER_NAME, STOP_NAME
from mii.batching.data_classes import Response, Request, RequestBatch
from mii.batching.generation.logit_processors import TopPLogitProcessor, TopKLogitProcessor, TemperatureLogitProcessor
from mii.batching.generation.samplers import LogitsSampler, GreedySampler
from mii.batching.generation.stop_criterion import EosGenerationStopCriterion, TokenStopCriterion
from mii.batching.postprocess import (
run_batch_logit_processing,
run_batch_sampler,
run_batch_stop_criterion,
)
from mii.batching.utils import sync_debug, profiler
from mii.config import GenerateParamsConfig
from mii.constants import GenerationFinishReason, ZMQ_RECV_TIMEOUT
from mii.logging import logger
from mii.modeling.tokenizers import MIITokenizerWrapper
class RaggedBatchBase:
def __init__(self, inference_engine, tokenizer, model_config):
self.inference_engine = inference_engine
self.tokenizer = tokenizer
self.vocab_size = tokenizer.vocab_size
self.model_config = model_config
self.zmq_port = model_config.zmq_port_number
# Set max sequence length from either user-passed model_config or from
# HF model_config
if model_config.max_length is not None:
self.max_length = model_config.max_length
else:
self.max_length = inference_engine._policy._checkpoint_engine.model_config.max_seq_length
self.sync_debug = model_config.sync_debug
self.profile_model_time = model_config.profile_model_time
# Create queues and other values for scheduling of requests and results
self.request_queue: queue.Queue = queue.Queue()
self.result_queues: Dict[int, queue.Queue] = {}
self.scheduled_requests: RequestBatch = RequestBatch()
self.buffer = deque()
self.scheduled_length = 0
self.scheduled_seq_num = 0
self.scheduled_req_blocks = 0
# TODO: Each request we process can have a unique post_processor (e.g.,
# different temperature value). We will need to prune
# self._post_processors for long running deployments
self._post_processors = {}
self.logit_processor = run_batch_logit_processing
self.sampler = run_batch_sampler
self.stop_criterion = run_batch_stop_criterion
# If profiling is enabled, these are used to capture/generate data
self._timers: SynchronizedWallClockTimer = SynchronizedWallClockTimer()
self._profiled_times: DefaultDict[str, List[int]] = defaultdict(list)
self._iters: int = 0
self._num_generated_tokens: int = 0
# Use ZMQ because it is light-weight and fast for passing simple
# messages (i.e., token sequences) between each TP process of the
# inference engine
self._zmq_context = zmq.Context()
torch.cuda.synchronize()
if self.is_rank_0:
self.socket = self._zmq_context.socket(zmq.PUB)
self.socket.bind(f"tcp://*:{self.zmq_port}")
time.sleep(1) # Give the subscriber a chance to connect
else:
self.socket = self._zmq_context.socket(zmq.SUB)
self.socket.connect(f"tcp://localhost:{self.zmq_port}")
self.socket.setsockopt_string(zmq.SUBSCRIBE, "")
self.socket.setsockopt(zmq.RCVTIMEO, ZMQ_RECV_TIMEOUT)
@cached_property
def local_rank(self) -> int:
return get_accelerator().current_device()
@property
def is_rank_0(self) -> bool:
return self.local_rank == 0
@profiler
def generate(self) -> Union[None, bool]:
"""
This is the main loop of FastGen: puts requests and gets generated results.
"""
# 1. Get a batch of requests, broadcast to all ranks
scheduled_requests, force = self._bcast_requests()
# 2. Flush for uids that are finished generating
self.flush(scheduled_requests.requests_to_flush.uids)
# 3. Put new tokens into inference engine
if scheduled_requests.requests_to_run:
next_token_logits = self.put(
scheduled_requests.requests_to_run.uids,
scheduled_requests.requests_to_run.tokens,
)
# short circuit if not rank 0, only rank 0 does scheduling and postprocessing of logits
if not self.is_rank_0:
return force
# 4. Launch logit processing and token generation
running_requests = scheduled_requests.requests_to_run
running_requests.update_seq_length()
if running_requests:
next_tokens, done_tokens = self._process_logits(
next_token_logits, running_requests
)
running_requests.next_tokens = next_tokens
running_requests.done_tokens = done_tokens
# 5. Schedule requests while we wait for the forward pass to finish
self._reset_scheduler_bookkeeping()
# 6. Accumulate generated tokens, check completion, and generate output
for r in running_requests.last_in_prompt:
r.accumulate_generated_token()
self._num_generated_tokens += 1
if r.stop_generation or r.stream:
self._generate_output(r)
if not r.stop_generation:
r.set_next_as_input()
self.request_queue.put(r)
# 7. Update scheduled requests
self.scheduled_requests.prune(running_requests.completed.uids)
self.schedule_requests()
if self.profile_model_time:
self._print_profiled_times()
def _print_profiled_times(self) -> None:
self._iters += 1
if not (self._iters % 100 == 0):
return
for event, times in self._profiled_times.items():
mean_time = sum(times) / len(times)
log_msg = f"{event}: {mean_time}"
if event == "generate":
log_msg += f" ({self._num_generated_tokens / sum(times)} tokens/ms)"
logger.info(log_msg)
self._profiled_times.clear()
self._num_generated_tokens = 0
@sync_debug
def _bcast_requests(self, force=False) -> RequestBatch:
# Rank 0 is the main process that does scheduling of requests on the
# inference engine. When new requests are to be placed on the engine,
# the prompt tokens must be broadcast to all TP processes.
if self.is_rank_0:
if not self.scheduled_requests and not force:
return self.scheduled_requests, force
# Rank 0 gets batch of requests and broadcasts to other ranks
data_dicts = self.scheduled_requests.to_msg_dicts()
json_data = ujson.dumps({"data": data_dicts, "force": force})
self.socket.send_string(json_data)
else:
try:
json_data = self.socket.recv_string()
recv_dict = ujson.loads(json_data)
data_dicts = recv_dict["data"]
force = recv_dict["force"]
self.scheduled_requests = RequestBatch.from_msg_dicts(data_dicts)
except zmq.Again:
self.scheduled_requests = RequestBatch()
return self.scheduled_requests, force
def _reset_scheduler_bookkeeping(self) -> None:
self.scheduled_requests = RequestBatch()
self.scheduled_length = 0
self.scheduled_seq_num = 0
self.scheduled_req_blocks = 0
@sync_debug
def _process_logits(
self,
next_token_logits: torch.Tensor,
running_requests: RequestBatch) -> Tuple[torch.Tensor,
torch.Tensor]:
# Process generated logits, run post processing, gets next token, and
# checks for stop criteria at each round of generation for all requests.
next_token_logits = next_token_logits[:, :self.vocab_size]
next_token_logits = self.logit_processor(next_token_logits,
running_requests,
self._post_processors)
next_tokens = self.sampler(next_token_logits,
running_requests,
self._post_processors)
done_tokens = self.stop_criterion(next_tokens,
running_requests,
self._post_processors)
next_tokens = next_tokens.to(torch.device("cpu"), non_blocking=False)
done_tokens = done_tokens.to(torch.device("cpu"), non_blocking=False)
return next_tokens, done_tokens
@sync_debug
def _generate_output(self, r: Request) -> bool:
# Gather generated tokens and put them in the result queue. For
# streaming, this happens at every generated token. For non-streaming,
# this happens only when a stop criteria is met.
outputs = []
if r.stream:
outputs.append((
r.uid,
[r.next_token],
r.prompt_length,
r.num_generated_tokens,
GenerationFinishReason.NONE,
r.stream,
))
if r.finish_reason != GenerationFinishReason.NONE:
if r.stream or not r.generated_tokens:
output_tokens = []
else:
output_tokens = torch.cat([t.unsqueeze(0) for t in r.generated_tokens],
dim=0)
if r.return_full_text:
# Avoid returning bos token, refactor this later
output_tokens = torch.cat((r.prompt_tokens[1:], output_tokens))
outputs.append((
r.uid,
output_tokens,
r.prompt_length,
r.num_generated_tokens,
r.finish_reason,
r.stream,
))
for output in outputs:
self.result_queues[r.tid].put_nowait(output)
def _schedule_token_gen(self, requests: List[Request]) -> None:
free_blocks = min(self.inference_engine.free_blocks)
conf_manager = self.inference_engine._config.state_manager
num_schedulable = min([
len(requests),
conf_manager.max_ragged_sequence_count,
conf_manager.max_ragged_batch_size
])
for r in requests[:num_schedulable]:
block_capacity = self.inference_engine.get_remaining_block_capacity(r.uid)
# We can schedule token generation if the last block has a capacity
if block_capacity > 0:
self.scheduled_length += 1
self.scheduled_requests.append(r)
elif free_blocks > 0:
# We need a new block
free_blocks -= 1
self.scheduled_length += 1
self.scheduled_req_blocks += 1
self.scheduled_requests.append(r)
def _schedule_prompts(self, requests: List[Request]) -> None:
free_blocks = min(self.inference_engine.free_blocks)
conf_manager = self.inference_engine._config.state_manager
for r in requests:
if free_blocks == 0:
break
if r.max_length <= r.seq_length:
continue
# Make sure that the engine has enough capacity to process the batch
if len(self.scheduled_requests.requests_to_run
) >= conf_manager.max_ragged_sequence_count:
break
max_batch_size = conf_manager.max_ragged_batch_size - self.scheduled_length
if max_batch_size <= 0:
break
max_blocks = free_blocks - self.scheduled_req_blocks
if len(r.input_tokens) > 1:
# When the KV cache is out of capacity, we release KV cache blocks for a request.
# However, we can immediately schedule the request again if we split the request.
# So we make sure that we have capacity for the entire prompt (+tokens already generated).
req_tokens, _ = self.inference_engine.query(r.uid, len(r.input_tokens), max_blocks)
if req_tokens < len(r.input_tokens):
break
req_tokens = min(len(r.input_tokens), max_batch_size)
req_tokens, req_blocks = self.inference_engine.query(r.uid, req_tokens, max_blocks)
if req_tokens <= 0:
continue
# Decompose the prompt to fit to the max ragged batch size
decomposed = req_tokens < len(r.input_tokens)
remaining_tokens = r.input_tokens[req_tokens:]
r.input_tokens = r.input_tokens[:req_tokens]
r.last_in_prompt = not decomposed
# Schedule the request
self.scheduled_requests.append(r)
self.scheduled_req_blocks += req_blocks
self.scheduled_length += req_tokens
if decomposed:
req_remaining = copy.copy(r)
req_remaining.input_tokens = remaining_tokens
req_remaining.seq_length = r.seq_length + req_tokens
req_remaining.last_in_prompt = True
self.buffer.appendleft(req_remaining)
def schedule_requests(self) -> None:
while not self.request_queue.empty():
r = self.request_queue.get_nowait()
self.buffer.append(r)
next_token_gen_reqs = []
prompt_reqs = []
for r in self.buffer:
if r.is_flush_request:
self.scheduled_requests.append(r)
else:
if r.num_generated_tokens > 0:
if r.max_length > r.seq_length:
next_token_gen_reqs.append(r)
else:
prompt_reqs.append(r)
# We want to process next token generation first
self._schedule_token_gen(next_token_gen_reqs)
self._schedule_prompts(prompt_reqs)
if len(self.buffer) > 0 and len(self.scheduled_requests) == 0:
self.scheduled_requests = RequestBatch()
self.reset_request_status()
else:
scheduled_requests_ids = set(id(r) for r in self.scheduled_requests)
self.buffer = deque(
[r for r in self.buffer if id(r) not in scheduled_requests_ids])
def _queue_flush_request(self, uid: int) -> None:
self.request_queue.put_nowait(
Request(
tid=None,
uid=uid,
input_tokens=None,
prompt_tokens=None,
seq_length=None,
last_in_prompt=None,
post_processing=None,
generate_params=None,
))
def reset_request_status(self):
## Get the last request that consumes KV cache
last_r = None
for r in self.buffer:
if r.seq_length > 0:
last_r = r
assert last_r is not None, "Function to clear the KV cache is invoked, but no request consumes KV cache"
## Schedule flushing r
self.scheduled_requests.append(
Request(
tid=None,
uid=last_r.uid,
input_tokens=None,
prompt_tokens=None,
seq_length=None,
last_in_prompt=None,
post_processing=None,
generate_params=None,
))
## Rebuild the request
new_req = copy.copy(last_r)
new_req.prompt_tokens = new_req.input_tokens = torch.concat(
[last_r.prompt_tokens] + [t.unsqueeze(0) for t in last_r.generated_tokens])
new_req.seq_length = 0
new_req.max_new_tokens = last_r.max_new_tokens - len(last_r.generated_tokens)
new_req.clear_generated_token()
## Remove the requests from buffer and queue
new_buffer = deque()
for r in self.buffer:
if r.uid != last_r.uid:
new_buffer.append(r)
while not self.request_queue.empty():
r = self.request_queue.get_nowait()
if r.uid != last_r.uid:
new_buffer.append(r)
new_buffer.append(new_req)
self.buffer = new_buffer
def make_request(self,
tid: int,
uid: int,
input_tokens: torch.Tensor,
kwargs: Dict) -> Request:
kwargs["prompt_length"] = len(input_tokens)
kwargs["max_length"] = kwargs.get("max_length", self.max_length)
generate_params = GenerateParamsConfig(**kwargs)
post_processing = []
top_p = generate_params.top_p
top_p_name = "_".join((TOP_P_NAME, str(top_p)))
if top_p_name not in self._post_processors:
self._post_processors[top_p_name] = TopPLogitProcessor(top_p=top_p)
post_processing.append(top_p_name)
top_k = generate_params.top_k
if top_k is not None:
top_k_name = "_".join((TOP_K_NAME, str(top_k)))
if top_k_name not in self._post_processors:
self._post_processors[top_k_name] = TopKLogitProcessor(top_k=top_k)
post_processing.append(top_k_name)
temp = generate_params.temperature
if temp is not None:
temp_name = "_".join((TEMP_NAME, str(temp)))
if temp_name not in self._post_processors:
self._post_processors[temp_name] = TemperatureLogitProcessor(
temperature=temp)
post_processing.append(temp_name)
do_sample = generate_params.do_sample
if do_sample:
sampler_name = "_".join((SAMPLER_NAME, "logits"))
if sampler_name not in self._post_processors:
self._post_processors[sampler_name] = LogitsSampler()
else:
sampler_name = "_".join((SAMPLER_NAME, "greedy"))
if sampler_name not in self._post_processors:
self._post_processors[sampler_name] = GreedySampler()
post_processing.append(sampler_name)
stop = generate_params.stop
if stop != []:
for each_stop in stop:
stop_name = STOP_NAME + '_' + each_stop
if stop_name not in self._post_processors:
self._post_processors[stop_name] = TokenStopCriterion(
token=each_stop,
tokenizer=self.tokenizer)
post_processing.append(stop_name)
else:
stop_name = STOP_NAME
if STOP_NAME not in self._post_processors:
self._post_processors[stop_name] = EosGenerationStopCriterion(
tokenizer=self.tokenizer)
post_processing.append(stop_name)
return Request(
tid=tid,
uid=uid,
input_tokens=input_tokens,
prompt_tokens=input_tokens,
seq_length=0,
last_in_prompt=True,
post_processing=post_processing,
generate_params=generate_params,
)
def make_response(self,
generated_text: str,
prompt_length: int,
generated_length: int,
finish_reason: GenerationFinishReason) -> Response:
return Response(generated_text=generated_text,
prompt_length=prompt_length,
generated_length=generated_length,
finish_reason=finish_reason)
def put(self, uids: List[int], tokenized_input: List[torch.Tensor]) -> torch.Tensor:
# Call inference engine. You can skip checking schedulability because we already checked when scheduling
return self.inference_engine.put(uids, tokenized_input, do_checks=False)
def flush(self, uids: List[int]) -> None:
for uid in uids:
self.inference_engine.flush(uid)
@dataclass
class StreamState:
prev_token_size: int
token_ids: List[int]
class ReadableStream:
def __init__(self, tokenizer: MIITokenizerWrapper) -> None:
self.tokenizer = tokenizer
self.stream_state: Dict[int, StreamState] = {}
def init_state(self, thread_id: int) -> StreamState:
if thread_id not in self.stream_state:
self.stream_state[thread_id] = StreamState(token_ids=[], prev_token_size=0)
return self.stream_state[thread_id]
return self.stream_state[thread_id]
def flush_state(self, thread_id: int) -> None:
if thread_id in self.stream_state:
del self.stream_state[thread_id]
def decode(self, thread_id: int, token_ids: List[int]) -> str:
state = self.init_state(thread_id)
output = []
for token_id in token_ids:
state.token_ids.append(token_id)
decoded = self.tokenizer.decode(state.token_ids)
# We don't have enough token_ids in the buffer and
# tokenizer returned unicode 'U+FFFD REPLACEMENT CHARACTER'
if "\ufffd" in decoded:
continue
if state.prev_token_size > 0:
prev_token = state.token_ids[:state.prev_token_size]
state.token_ids = state.token_ids[state.prev_token_size:]
decoded = decoded.replace(self.tokenizer.decode(prev_token), "", 1)
output.append(decoded)
state.prev_token_size = len(state.token_ids)
return "".join(output)
[docs]class MIIPipeline(RaggedBatchBase):
"""
Pipeline class that inherits from :class:`RaggedBatchBase` and provides
functionality of ragged batching and dynamic splitfuse. This class is
returned from :func:`mii.pipeline`.
"""
def __init__(self, all_rank_output: bool = False, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.tid = threading.get_ident()
self._all_rank_output = all_rank_output
self._destroyed = False
get_accelerator().set_device(int(os.getenv("LOCAL_RANK", "0")))
[docs] def __call__(self,
prompts: Union[str,
List[str]],
**generate_kwargs) -> List[Response]:
"""
Generates text for the given prompts
:param prompts: The string or list of strings used as prompts for generation.
:param \\*\\*generate_kwargs: Generation keywords. A full list can be found
in :class:`GenerateParamsConfig <mii.config.GenerateParamsConfig>`.
:return: A list of :class:`Response` objects containing the generated
text for all prompts.
""" # noqa: W605
if self._destroyed:
raise RuntimeError(
"The inference engine of this pipeline has been destroyed.")
if isinstance(prompts, str):
prompts = [prompts]
outputs: List[Response] = []
uids_running: List[int] = list(range(len(prompts)))
uids_complete_order: List[int] = []
for uid, input in zip(uids_running, prompts):
request_kwargs = generate_kwargs.copy()
self._put_request(uid, input, request_kwargs)
if self.is_rank_0:
# Rank 0 runs generate() until all responses are returned
while uids_running:
while not self.result_queues[self.tid].empty():
uid, response = self._get_response()
outputs.append(response)
self._queue_flush_request(uid)
uids_complete_order.append(uid)
uids_running.remove(uid)
self.generate()
# Ensure final flush requests broadcast and
# kick ranks 1 -> n out of the while loop
self._bcast_requests(force=True)
self.flush(self.scheduled_requests.requests_to_flush.uids)
self.scheduled_requests = RequestBatch()
else:
# Ranks 1 -> n just run generate() until there are no more requests
exit = False
while not exit:
exit = self.generate()
outputs = [
r for idx,
r in sorted(zip(uids_complete_order,
outputs),
key=lambda pair: pair[0])
]
if self._all_rank_output:
outputs = self._bcast_responses(outputs)
return outputs
def _put_request(self, uid: int, input: str, kwargs: Dict[str, Any]) -> None:
self.result_queues[self.tid] = queue.Queue()
input_tokens = self.tokenizer.encode(input)
request = self.make_request(self.tid, uid, input_tokens, kwargs)
self.request_queue.put(request)
def _get_response(self) -> Tuple[int, Response]:
result = self.result_queues[self.tid].get()
uid = result[0]
generated_tokens = self.tokenizer.decode(result[1])
response = self.make_response(generated_tokens, result[2], result[3], result[4])
return uid, response
def _bcast_responses(self, responses: List[Response]) -> List[Response]:
if self.is_rank_0:
data_dicts = [r.to_msg_dict() for r in responses]
json_data = ujson.dumps(data_dicts)
self.socket.send_string(json_data)
else:
json_data = self.socket.recv_string()
data_dicts = ujson.loads(json_data)
responses = [Response.from_msg_dict(msg) for msg in data_dicts]
return responses
def destroy(self) -> None:
del self.inference_engine
self.socket.close()
self._zmq_context.term()
gc.collect()
get_accelerator().empty_cache()
self._destroyed = True
class MIIAsyncPipeline(RaggedBatchBase):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.uids = set()
self.lock = threading.Lock()
self.thread = None
self.stop_thread = False
self._is_shutdown = False
self.UID_RANGE_LB = 1
self.UID_RANGE_UB = 10000
self.readable_stream = ReadableStream(self.tokenizer)
def __call__(self) -> None:
# CUDA device gets reset, must set it again to avoid problems
get_accelerator().set_device(int(os.getenv("LOCAL_RANK", "0")))
while True:
self.generate()
if (self.stop_thread and self.request_queue.empty()
and all(q.empty() for q in self.result_queues.values())):
break
def _get_uid(self) -> int:
with self.lock:
if len(self.uids) >= self.UID_RANGE_UB - self.UID_RANGE_LB:
raise RuntimeError("No available choices for a new UID.")
uid = random.randrange(self.UID_RANGE_LB, self.UID_RANGE_UB)
while uid in self.uids:
uid = random.randrange(self.UID_RANGE_LB, self.UID_RANGE_UB)
self.uids.add(uid)
return uid
def put_request(self, prompt: str, kwargs: Dict) -> int:
# TODO: We should avoid any request/response work with non-rank 0, but
# this requires some refactoring how we do the put and request in
# `ModelResponse`
#if not self.is_rank_0:
# return
if self.stop_thread:
raise RuntimeError("The request queue was shutdown.")
uid = self._get_uid()
try:
# Temporary hack to avoid non-rank 0 processes not shutting down. See
# related TODO above.
if not self.is_rank_0:
return uid
tid = threading.get_ident()
with self.lock:
if tid not in self.result_queues:
self.result_queues[tid] = queue.Queue()
input_tokens = self.tokenizer.encode(prompt)
request = self.make_request(tid, uid, input_tokens, kwargs)
self.request_queue.put(request)
return uid
except:
# It is OK to have `self.request_queue.put(request)` in the try block since
# it will never raise exceptions with unlimited queue size. If any exception
# occurred in the above block, the `request` obj was not enqueued.
self.flush_uid(uid)
raise
def get_response(self) -> Tuple[int, Response]:
# TODO: We should avoid any request/response work with non-rank 0, but
# this requires some refactoring how we do the put and request in
# `ModelResponse`
if not self.is_rank_0:
return -1, Response(generated_text="",
prompt_length=None,
generated_length=None,
finish_reason=None)
tid = threading.get_ident()
uid, generated_token_ids, prompt_length, generated_length, finish_reason, streaming = self.result_queues[tid].get()
if len(generated_token_ids) == 0:
generated_text = ""
self.readable_stream.flush_state(tid)
elif streaming:
generated_text = self.readable_stream.decode(tid, generated_token_ids)
else:
generated_text = self.tokenizer.decode(generated_token_ids)
response = self.make_response(
generated_text=generated_text,
prompt_length=prompt_length,
generated_length=generated_length,
finish_reason=finish_reason,
)
return uid, response
def start(self) -> None:
self.thread = threading.Thread(target=self, daemon=True)
self.thread.start()
def shutdown(self) -> None:
self.stop_thread = True
self.thread.join()
self._is_shutdown = True
def is_shutdown(self) -> bool:
return self._is_shutdown
def flush_uid(self, uid: int) -> None:
with self.lock:
if self.is_rank_0:
self._queue_flush_request(uid)
self.uids.remove(uid)