Forward arbitrary kwargs to remote blocks #467

Status: Open. Wants to merge 40 commits into main (diff shown reflects changes from 39 commits).

Commits (40):
1879788  typos  (Aug 16, 2023)
f313730  WIP  (Aug 17, 2023)
ed8d7f4  mwp  (Aug 17, 2023)
fb9b211  black-isort  (Aug 17, 2023)
355c150  black-isort  (Aug 17, 2023)
084d565  priority pool  (Aug 17, 2023)
65e8739  wip (again)  (Aug 22, 2023)
13c13d3  wip (again)  (Aug 22, 2023)
4529471  black, isort  (Aug 22, 2023)
d51c08e  undo debug change  (Aug 22, 2023)
1e5df29  Merge branch 'main' into forward_kwargs  (justheuristic, Aug 22, 2023)
22bcbb3  Merge branch 'main' into forward_kwargs  (justheuristic, Aug 24, 2023)
09e9da6  serialize outputs structure  (Aug 25, 2023)
84ebd57  WIP, switching to another PR  (Aug 28, 2023)
49ff759  undo  (Aug 28, 2023)
ce89b64  Merge branch 'main' into forward_kwargs  (justheuristic, Sep 1, 2023)
6256995  Merge remote-tracking branch 'origin/main' into forward_kwargs  (Sep 5, 2023)
6c7f762  rollback: only generic kwarg  (Sep 5, 2023)
cc4fe17  minimize diff  (Sep 5, 2023)
2e76031  add docstr  (Sep 5, 2023)
e5c2d8e  WIP BEFORE MEETING NEED BACKWARD UPDATE  (Sep 5, 2023)
49474e5  wip some more  (Sep 5, 2023)
4393d99  1isort  (Sep 5, 2023)
465fd93  more WIP  (Sep 5, 2023)
f204965  make it work for fwd, bwd  (Sep 5, 2023)
b7bd477  black-isort  (Sep 5, 2023)
9e29140  mention reference issue  (Sep 5, 2023)
17d278e  black-isort-clarify  (Sep 5, 2023)
62e780c  check num block kwargs  (Sep 6, 2023)
aacd8b2  pass args/kwargs via forward  (Sep 6, 2023)
056cd77  standardize checking block_kwargs  (Sep 6, 2023)
a23bd73  probably break everyting  (Sep 6, 2023)
68b8cea  note  (Sep 6, 2023)
8eb1722  standardize: s/backend_kwargs/block_kwargs/g everywhere  (Sep 6, 2023)
3bffcde  black+isort  (Sep 6, 2023)
721f7d2  unbreak everything  (Sep 6, 2023)
3048c3b  rollback  (Sep 6, 2023)
3f06b53  temporary rollback: allow kwargs only at first inference step  (Sep 6, 2023)
c665c42  reduce diff  (Sep 6, 2023)
3195579  Merge remote-tracking branch 'origin/main' into forward_kwargs  (justheuristic, Dec 2, 2023)

3 changes: 3 additions & 0 deletions src/petals/__init__.py
@@ -24,6 +24,9 @@
assert (
version.parse("4.32.0") <= version.parse(transformers.__version__) < version.parse("5.0.0")
), "Please install a proper transformers version: pip install transformers>=4.32.0,<5.0.0"
assert version.parse("1.1.10") <= version.parse(
hivemind.__version__
), "Please install a proper hivemind version: pip install hivemind>=1.1.10"


def _override_bfloat16_mode_default():
139 changes: 81 additions & 58 deletions src/petals/client/inference_session.py
@@ -4,18 +4,16 @@
import itertools
import time
import uuid
from typing import AsyncIterator, List, Optional, Tuple
from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple

import torch
from hivemind import MSGPackSerializer, anext, deserialize_torch_tensor, get_logger, serialize_torch_tensor
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.moe.client.remote_expert_worker import RemoteExpertWorker
from hivemind.p2p import P2P
from hivemind.proto import runtime_pb2
from hivemind.utils.tensor_descr import BatchTensorDescriptor
from hivemind.utils import MSGPackSerializer, anext, get_logger, nested_flatten

from petals.client.config import ClientConfig
from petals.client.routing import RemoteSequenceManager, maybe_log_traceback
from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo, RPCInfo
from petals.data_structures import CHAIN_DELIMITER, ModuleUID, RemoteSpanInfo
from petals.server.handler import TransformerConnectionHandler
from petals.utils.misc import DUMMY, DUMMY_INT64, is_dummy
from petals.utils.packaging import pack_args_kwargs
@@ -32,48 +30,48 @@ class _ServerInferenceSession:

def __init__(
self,
config: ClientConfig,
sequence_manager: RemoteSequenceManager,
span: RemoteSpanInfo,
uid: ModuleUID,
rpc_info: RPCInfo,
span_uids: Sequence[ModuleUID],
inputs_queue: asyncio.Queue,
outputs_aiter: AsyncIterator,
*,
*block_kwargs,
max_length: int,
**metadata,
):
self.config = config
self.span, self.uid, self.rpc_info = span, uid, rpc_info
self.num_blocks = uid.count(CHAIN_DELIMITER) + 1
self.sequence_manager = sequence_manager
self.span, self.span_uids = span, span_uids
self.num_blocks = len(span_uids)
self._inputs_queue: asyncio.Queue[runtime_pb2.ExpertRequest] = inputs_queue
self._outputs_stream: AsyncIterator[runtime_pb2.ExpertResponse] = outputs_aiter
self.session_id = str(uuid.uuid4())
self.session_metadata = dict(max_length=max_length, **metadata)
self.max_length = max_length
self.stepped = False
self.closed = False

self._position = 0
self.history = None # Used in case of server failures to regenerate attention caches on new servers
self.next_session = None

self.block_kwargs = block_kwargs
assert len(self.block_kwargs) in (0, self.num_blocks)

@classmethod
async def create(
cls,
config: ClientConfig,
p2p: P2P,
sequence_manager: RemoteSequenceManager,
span: RemoteSpanInfo,
uid: ModuleUID,
rpc_info: RPCInfo,
**metadata,
span_uids: Sequence[ModuleUID],
*block_kwargs: Dict[str, Any],
**kwargs,
) -> _ServerInferenceSession:
"""Create a new session for a given remote module. This code is meant to be run inside RemoteExpertWorker"""
stub = TransformerConnectionHandler.get_stub(p2p, span.peer_id)
stub = TransformerConnectionHandler.get_stub(sequence_manager.state.p2p, span.peer_id)
inputs_queue = asyncio.Queue()
outputs_stream = await asyncio.wait_for(
stub.rpc_inference(cls._read_inputs_from_queue(inputs_queue)),
config.connect_timeout,
sequence_manager.config.connect_timeout,
)
return cls(config, span, uid, rpc_info, inputs_queue, outputs_stream, **metadata)
return cls(sequence_manager, span, span_uids, inputs_queue, outputs_stream, *block_kwargs, **kwargs)

@staticmethod
async def _read_inputs_from_queue(queue: asyncio.Queue, input_timeout: Optional[float] = None) -> AsyncIterator:
@@ -87,13 +85,13 @@ def step(
self,
inputs: torch.Tensor,
prompts: Optional[torch.Tensor] = None,
hypo_ids: Optional[torch.Tensor] = None,
*,
hypo_ids: Optional[torch.Tensor] = None,
step_id: str,
) -> torch.Tensor:
"""
Inference step: send a chunk of input tensors and receive a chunk of outputs
:prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
:param prompts: optional DEEP prompts, added to a prefix of each layer's outputs,
if specified, deep prompts should have shape [num_layers, batch_size, prefix_len, hid_size]
"""
if self.closed:
@@ -111,8 +109,10 @@ def step(

if not self.stepped:
inputs = self.history # Pass full inputs including prefix
block_kwargs = self.block_kwargs
else:
inputs = inputs[:, -n_input_tokens:] # No need to pass prefix further
block_kwargs = []

if prompts is None or is_dummy(prompts):
prompts = DUMMY
@@ -129,38 +129,50 @@ def step(
assert len(hypo_ids) == len(inputs)
assert hypo_ids.dtype == torch.int64

# serialize inputs and put them into the queue
input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids)

request_metadata = dict(session_id=self.session_id, step_id=step_id)
if not self.stepped:
request_metadata.update(self.session_metadata)
elif self.config.use_server_to_server:
metadata = dict(session_id=self.session_id, step_id=step_id, max_length=self.max_length)
metadata.update(
self.sequence_manager.get_request_metadata(
self.span.peer_id,
"rpc_inference",
self.span_uids,
inputs,
prompts,
*block_kwargs,
max_length=self.max_length,
session_id=self.session_id,
step_id=step_id,
)
)
if self.stepped and self.sequence_manager.config.use_server_to_server:
next_servers = self._collect_next_servers()
if next_servers:
request_metadata["next_servers"] = next_servers
metadata["next_servers"] = next_servers

request_metadata["args_structure"] = args_structure
codecs = self.sequence_manager.get_compression_codecs(
self.span.peer_id, "rpc_inference", self.span_uids, inputs, prompts, *block_kwargs
)

# TODO: make possible to use different compression method for different tensors
server_side_inference_schema, kwargs_schema = self.rpc_info["inference_schema"]
compression = server_side_inference_schema[0].compression
inference_schema = tuple(BatchTensorDescriptor.from_tensor(arg, compression) for arg in input_tensors)
# serialize inputs and put them into the queue
input_tensors, args_structure = pack_args_kwargs(inputs, prompts, hypo_ids, *block_kwargs)
args_structure = metadata.setdefault("args_structure", args_structure)

# TODO: create more explicit way to check servers schema and client's structure
assert len(input_tensors) >= len(
server_side_inference_schema
), "Hidden_state, prompts and hypo_ids tensors are necessary for an inference step"
if codecs is None:
codecs = [runtime_pb2.CompressionType.NONE] * len(input_tensors)
else:
codecs = list(nested_flatten(codecs))
assert len(codecs) == len(
input_tensors
), f"got {len(input_tensors)} tensors but {len(codecs)} compression codecs"

outputs_serialized = RemoteExpertWorker.run_coroutine(
self._step(
runtime_pb2.ExpertRequest(
uid=self.uid,
uid=CHAIN_DELIMITER.join(self.span_uids),
tensors=[
serialize_torch_tensor(tensor.to(proto.dtype), proto.compression)
for tensor, proto in zip(input_tensors, inference_schema)
serialize_torch_tensor(tensor, compression)
for tensor, compression in zip(input_tensors, codecs)
],
metadata=MSGPackSerializer.dumps(request_metadata),
metadata=MSGPackSerializer.dumps(metadata),
)
)
)
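
Note: the hunk above replaces the fixed server-side inference schema with a flatten-and-describe approach. All arguments, including arbitrary per-block kwargs, are flattened into a single list of tensors plus an args_structure descriptor, and exactly one compression codec is paired with each flattened tensor. Below is a minimal sketch of that flatten/re-nest idea; the function names and descriptor layout are illustrative assumptions, not the actual petals.utils.packaging implementation.

from typing import Any, Dict, List, Tuple

import torch


def flatten_args(*args: torch.Tensor, **kwargs: torch.Tensor) -> Tuple[List[torch.Tensor], Dict[str, Any]]:
    """Split args/kwargs into a flat tensor list plus a descriptor of where each tensor came from."""
    tensors: List[torch.Tensor] = []
    structure: Dict[str, Any] = {"args": [], "kwargs": {}}
    for value in args:
        tensors.append(value)
        structure["args"].append(len(tensors) - 1)
    for name, value in kwargs.items():
        tensors.append(value)
        structure["kwargs"][name] = len(tensors) - 1
    return tensors, structure


def unflatten_args(tensors: List[torch.Tensor], structure: Dict[str, Any]):
    """Re-nest a flat tensor list on the receiving side using the descriptor."""
    args = [tensors[i] for i in structure["args"]]
    kwargs = {name: tensors[i] for name, i in structure["kwargs"].items()}
    return args, kwargs


# Example: three standard tensors plus one per-block kwarg travel as one flat list.
tensors, structure = flatten_args(torch.zeros(1, 4), torch.zeros(1), torch.zeros(1), scale=torch.tensor(0.5))
assert len(tensors) == 4
assert unflatten_args(tensors, structure)[1]["scale"].item() == 0.5
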
@@ -187,7 +199,7 @@ async def _step(self, inputs_serialized: runtime_pb2.ExpertRequest) -> runtime_p
"""Inference step on serialized data. This code is meant to be run inside RemoteExpertWorker"""
await self._inputs_queue.put(inputs_serialized)
self.stepped = True
return await asyncio.wait_for(anext(self._outputs_stream), self.config.request_timeout)
return await asyncio.wait_for(anext(self._outputs_stream), self.sequence_manager.config.request_timeout)

def close(self):
"""Finish a given inference session, close the underlying connection"""
@@ -224,14 +236,20 @@ class InferenceSession:
An interface to a multi-step *inference* session for a sequence of remote transformer blocks
"""

def __init__(self, sequence_manager: RemoteSequenceManager, max_length: int):
def __init__(self, sequence_manager: RemoteSequenceManager, max_length: int, *block_kwargs: Dict[str, Any]):
self._sequence_manager = sequence_manager
self._closed = False
self._server_sessions = []
self._position = 0
self._max_length = max_length
self.output_ids = None

num_blocks = len(self._sequence_manager)
if len(block_kwargs) == 1:
block_kwargs = block_kwargs * num_blocks
assert len(block_kwargs) in (0, num_blocks), f"expected {num_blocks} block_kwargs, got {len(block_kwargs)}"
self.block_kwargs = block_kwargs

@property
def num_blocks(self) -> int:
return len(self._sequence_manager)
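
The constructor above also defines the broadcasting rule for block_kwargs: a single kwargs dict is repeated for every remote block, while any count other than zero or num_blocks fails the assertion. A standalone illustration of that rule (the kwarg name "adapter_scale" is hypothetical):

num_blocks = 4  # stands in for len(self._sequence_manager)

single = ({"adapter_scale": 0.5},)           # caller passes one dict
single = single * num_blocks                 # broadcast: the same dict for every block
assert len(single) == num_blocks

per_block = tuple({"adapter_scale": 0.1 * i} for i in range(num_blocks))
assert len(per_block) in (0, num_blocks)     # any other count would trip the constructor's assert
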
@@ -244,17 +262,13 @@ def _enter_server_sessions(self, chosen_spans: List[RemoteSpanInfo]) -> List[_Se
server_sessions = []
try:
for span in chosen_spans:
span_uids = CHAIN_DELIMITER.join(self._sequence_manager.block_uids[span.start : span.end])
metadata = self._sequence_manager.get_request_metadata("rpc_inference", span_uids, peer_id=span.peer_id)
session = RemoteExpertWorker.run_coroutine(
_ServerInferenceSession.create(
self._sequence_manager.config,
self._sequence_manager.state.p2p,
self._sequence_manager,
span,
span_uids,
rpc_info=self._sequence_manager.rpc_info,
self._sequence_manager.block_uids[span.start : span.end],
*self.block_kwargs[span.start : span.end],
max_length=self._max_length,
**metadata,
)
)
server_sessions.append(session)
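
_enter_server_sessions now also routes kwargs: each server session receives only the slice of block_kwargs covering the blocks in its span, so a server holding blocks [start, end) never sees kwargs meant for other blocks. A small sketch of that slicing with made-up spans (real spans are RemoteSpanInfo objects):

block_kwargs = tuple({"block_index": i} for i in range(8))   # illustrative per-block kwargs
spans = [(0, 3), (3, 8)]                                      # (start, end) handled by two servers

for start, end in spans:
    kwargs_for_span = block_kwargs[start:end]
    assert len(kwargs_for_span) == end - start
    # in the real code, these are forwarded as *kwargs_for_span to _ServerInferenceSession.create(...)
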
@@ -275,7 +289,13 @@ def __enter__(self) -> "InferenceSession":
assert not self._closed and not self._server_sessions
return self

def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **kwargs) -> torch.Tensor:
def step(
self,
inputs: torch.Tensor,
prompts: Optional[torch.Tensor] = None,
hypo_ids: Optional[torch.Tensor] = None,
) -> torch.Tensor:

assert not self._closed
if torch.is_grad_enabled():
logger.warning("Running inference session with grad enabled. Gradients will *not* be propagated correctly.")
Expand Down Expand Up @@ -310,7 +330,10 @@ def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **k

server_session = self._server_sessions[server_idx]
inputs = server_session.step(
inputs, prompts[server_session.span.start : server_session.span.end], step_id=step_id, **kwargs
inputs,
prompts[server_session.span.start : server_session.span.end],
hypo_ids=hypo_ids,
step_id=step_id,
)

server_idx += 1
@@ -336,7 +359,7 @@ def step(self, inputs: torch.Tensor, prompts: Optional[torch.Tensor] = None, **k
outputs = outputs.to(device=inputs_device, dtype=inputs_dtype)
return outputs

def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int) -> int:
def _update_sequence(self, server_idx: int, block_idx: int, attempt_no: int):
# If there is a failed server session, this code closes it
self._exit_server_sessions(self._server_sessions[server_idx : server_idx + 1])
