# Source code for mii.batching.data_classes

# Copyright (c) Microsoft Corporation.
# SPDX-License-Identifier: Apache-2.0

# DeepSpeed Team
from dataclasses import dataclass, field, asdict
from typing import Any, Dict, List, Iterator, Union
from typing_extensions import Self

import torch

from mii.constants import GenerationFinishReason
from mii.config import GenerateParamsConfig


@dataclass
class Response:
    """
    Response object returned from text-generation pipelines and persistent deployments.
    """

    generated_text: str
    """ The generated text. """

    prompt_length: int
    """ Number of tokens in the prompt. """

    generated_length: int
    """ Number of generated tokens. """

    finish_reason: GenerationFinishReason
    """ Reason for ending generation. One of :class:`mii.constants.GenerationFinishReason`. """

    @staticmethod
    def from_msg_dict(msg: Dict[str, Union[str, int]]) -> Self:
        """Build a ``Response`` using the entries of ``msg`` as keyword arguments."""
        return Response(**msg)

    def to_msg_dict(self) -> Dict[str, Union[str, int]]:
        """Return this response's fields as a dictionary (via ``dataclasses.asdict``)."""
        fields_as_dict = asdict(self)
        return fields_as_dict

    def __str__(self) -> str:
        # A response prints as its generated text only.
        return self.generated_text

    def __repr__(self) -> str:
        # Deliberately identical to ``__str__``: overrides the dataclass-generated repr.
        return self.generated_text
@dataclass
class RequestMsg:
    """
    Minimal request representation rebuilt from a message dictionary.

    Holds only the request ``uid`` and its input tokens. ``input_tokens`` being
    ``None`` marks the message as a flush request.
    """
    uid: int
    input_tokens: Union[None, torch.Tensor, List[int]]

    @property
    def is_flush_request(self):
        """``True`` when this message carries no input tokens."""
        return self.input_tokens is None

    @staticmethod
    def from_msg_dict(msg: Dict[str, Any]) -> Self:
        """
        Build a ``RequestMsg`` from a dict with ``"uid"`` and ``"input_tokens"`` keys.

        Non-``None`` input tokens are converted to an int32 CPU tensor; ``None``
        is passed through unchanged.
        """
        input_tokens = msg["input_tokens"]
        if input_tokens is not None:
            input_tokens = torch.tensor(input_tokens,
                                        dtype=torch.int32,
                                        device=torch.device("cpu"))
        return RequestMsg(uid=msg["uid"], input_tokens=input_tokens)


@dataclass
class Request:
    """
    State of a single generation request.

    Public fields describe the request itself; the underscore-prefixed fields
    hold per-request generation state and are exposed through properties.
    Several properties simply forward to ``generate_params``.
    """
    tid: int
    uid: int
    input_tokens: torch.Tensor
    prompt_tokens: torch.Tensor
    seq_length: int
    last_in_prompt: bool
    post_processing: List[object]
    generate_params: GenerateParamsConfig

    _next_token: Union[None, torch.Tensor] = None
    _is_done: bool = False
    _generated_tokens: List[torch.Tensor] = field(default_factory=list)
    _finish_reason: GenerationFinishReason = GenerationFinishReason.NONE

    @property
    def prompt_length(self) -> int:
        """Number of entries in ``prompt_tokens``."""
        return len(self.prompt_tokens)

    @property
    def next_token(self) -> Union[None, torch.Tensor]:
        """The most recently assigned next token, or ``None``."""
        return self._next_token

    @next_token.setter
    def next_token(self, next_token: Union[None, torch.Tensor]) -> None:
        self._next_token = next_token

    @property
    def ignore_eos(self) -> bool:
        return self.generate_params.ignore_eos

    @property
    def min_new_tokens(self) -> int:
        return self.generate_params.min_new_tokens

    @property
    def max_new_tokens(self) -> int:
        return self.generate_params.max_new_tokens

    @max_new_tokens.setter
    def max_new_tokens(self, max_new_tokens: int) -> None:
        # Writes through to the shared generate_params object.
        self.generate_params.max_new_tokens = max_new_tokens

    @property
    def stream(self) -> bool:
        return self.generate_params.stream

    @property
    def return_full_text(self) -> bool:
        return self.generate_params.return_full_text

    @property
    def max_length(self) -> int:
        return self.generate_params.max_length

    @property
    def is_done(self) -> bool:
        """
        Whether the request has been flagged as done.

        Always ``False`` when ``ignore_eos`` is set or while ``seq_length`` is
        below ``min_new_tokens``; otherwise the stored flag.
        """
        if self.ignore_eos:
            return False
        # NOTE(review): this compares the total seq_length (not
        # num_generated_tokens) against min_new_tokens - confirm this is intended.
        if self.seq_length < self.min_new_tokens:
            return False
        return self._is_done

    @is_done.setter
    def is_done(self, is_done: bool) -> None:
        self._is_done = is_done

    @property
    def generated_tokens(self) -> List[torch.Tensor]:
        """The list of accumulated generated tokens (the live list, not a copy)."""
        return self._generated_tokens

    @property
    def finish_reason(self) -> GenerationFinishReason:
        """Finish reason last recorded by ``stop_generation``."""
        return self._finish_reason

    @property
    def is_flush_request(self):
        """``True`` when this request carries no input tokens."""
        return self.input_tokens is None

    @property
    def num_generated_tokens(self) -> int:
        # We return zero while we are processing decomposed prompts
        if self.seq_length >= self.prompt_length:
            return self.seq_length - self.prompt_length + 1
        return 0

    @property
    def stop_generation(self) -> bool:
        """
        Returns whether to stop generation for this request.

        Side effect: records the matching ``GenerationFinishReason`` (``STOP``
        or ``LENGTH``) in ``_finish_reason`` when it returns ``True``.
        """
        if self.is_done:
            self._finish_reason = GenerationFinishReason.STOP
            return True
        if (self.seq_length >= self.max_length) or (self.num_generated_tokens >=
                                                    self.max_new_tokens):
            self._finish_reason = GenerationFinishReason.LENGTH
            return True
        return False

    def to_msg_dict(self) -> Dict[str, Any]:
        # Returns a minimal version of the request for purposes of broadcasting to all ranks
        input_tokens = self.input_tokens
        if input_tokens is not None:
            input_tokens = input_tokens.tolist()
        return {"uid": self.uid, "input_tokens": input_tokens}

    def accumulate_generated_token(self) -> None:
        # Append the latest token to the list of generated tokens
        if not self.is_done:
            self._generated_tokens.append(self.next_token)

    def clear_generated_token(self) -> None:
        """Empty the list of accumulated generated tokens in place."""
        self._generated_tokens.clear()

    def set_next_as_input(self) -> None:
        # Places the next token into the input token for next round of generation
        if self.next_token is not None:
            self.input_tokens = self.next_token.unsqueeze(0)
        self.last_in_prompt = True
        self.next_token = None
        self.is_done = False


class RequestBatch:
    """
    A list-backed collection of requests with batch-wide accessors.

    Supports ``len()``, ``in``, iteration and truthiness, plus filtered views
    that each return a new ``RequestBatch``.
    """
    def __init__(self, requests: Union[None, List[Request]] = None) -> None:
        # A fresh list per instance avoids sharing a mutable default.
        if requests is None:
            requests = []
        self.requests = requests

    def __len__(self) -> int:
        return len(self.requests)

    def __contains__(self, r: Request) -> bool:
        return r in self.requests

    def __bool__(self) -> bool:
        # Python 3 truthiness hook: a batch is truthy when it holds requests.
        return len(self.requests) != 0

    # Python 2 spelling, kept as an alias for any caller that invokes it directly.
    __nonzero__ = __bool__

    def __iter__(self) -> Iterator[Request]:
        return iter(self.requests)

    def __repr__(self) -> str:
        return f"RequestBatch({self.requests})"

    @property
    def requests_to_run(self) -> Self:
        """New batch holding the requests that are not flush requests."""
        return RequestBatch([r for r in self.requests if not r.is_flush_request])

    @property
    def requests_to_flush(self) -> Self:
        """New batch holding only the flush requests."""
        return RequestBatch([r for r in self.requests if r.is_flush_request])

    @property
    def last_in_prompt(self) -> Self:
        """New batch holding the requests whose ``last_in_prompt`` is set."""
        return RequestBatch([r for r in self.requests if r.last_in_prompt])

    @property
    def completed(self) -> Self:
        """New batch holding the requests whose ``stop_generation`` is true."""
        return RequestBatch([r for r in self.requests if r.stop_generation])

    @property
    def uids(self) -> List[int]:
        return [r.uid for r in self.requests]

    @property
    def lengths(self) -> List[int]:
        return [len(r.input_tokens) for r in self.requests]

    @property
    def tokens(self) -> List[torch.Tensor]:
        return [r.input_tokens for r in self.requests]

    @property
    def next_tokens(self) -> List[torch.Tensor]:
        return [r.next_token for r in self.requests]

    @next_tokens.setter
    def next_tokens(self, next_tokens: torch.Tensor) -> None:
        # One entry per request, assigned positionally.
        assert len(next_tokens) == len(self.requests)
        for idx, r in enumerate(self.requests):
            r.next_token = next_tokens[idx]

    @property
    def done_tokens(self) -> List[bool]:
        return [r.is_done for r in self.requests]

    @done_tokens.setter
    def done_tokens(self, done_tokens: torch.Tensor) -> None:
        # One entry per request; ``.item()`` converts each element to a Python scalar.
        assert len(done_tokens) == len(self.requests)
        for idx, r in enumerate(self.requests):
            r.is_done = done_tokens[idx].item()

    def to_msg_dicts(self) -> List[Dict[str, Any]]:
        """Return ``to_msg_dict()`` of every request, in order."""
        return [r.to_msg_dict() for r in self.requests]

    @staticmethod
    def from_msg_dicts(msg_dicts: List[Dict[str, Any]]) -> Self:
        """Build a batch of ``RequestMsg`` objects from message dictionaries."""
        return RequestBatch([RequestMsg.from_msg_dict(msg) for msg in msg_dicts])

    def prune(self, uids: List[int]) -> None:
        """Drop every request whose ``uid`` appears in ``uids``."""
        # Build the lookup set once instead of scanning the list per request.
        uids_to_drop = set(uids)
        self.requests = [r for r in self.requests if r.uid not in uids_to_drop]

    def append(self, r: Request) -> None:
        self.requests.append(r)

    def update_seq_length(self) -> None:
        """Advance each request's ``seq_length`` by the size of its current input tokens."""
        for r in self.requests:
            r.seq_length += r.input_tokens.size(0)