Skip to content

Commit

Permalink
suggested improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
madhur-ob committed Apr 16, 2024
1 parent 9c95627 commit c681f71
Showing 1 changed file with 75 additions and 45 deletions.
120 changes: 75 additions & 45 deletions metaflow/subprocess_manager.py
Original file line number Diff line number Diff line change
@@ -1,94 +1,112 @@
import os
import sys
import time
import signal
import shutil
import hashlib
import asyncio
import tempfile
from typing import List


def hash_command_invocation(command: List[str]) -> str:
    """Return a SHA-256 hex digest that identifies one invocation of a command.

    The digest covers the command's arguments, the current wall-clock time
    and a random nonce, so invoking the same command twice yields two
    different identifiers even when both calls land on the same clock tick.

    Parameters
    ----------
    command : List[str]
        The command and its arguments, one list element per argument.

    Returns
    -------
    str
        A 64-character lowercase hexadecimal string.
    """
    # Join on NUL instead of the empty string so that argument boundaries are
    # part of the hashed text: ["ab", "c"] and ["a", "bc"] must not collapse
    # to the same input.
    parts = [*command, str(time.time()), os.urandom(8).hex()]
    digest_input = "\0".join(parts)
    return hashlib.sha256(digest_input.encode()).hexdigest()
from typing import List, Dict, Optional, Callable


class LogReadTimeoutError(Exception):
    """Signals that reading a subprocess log did not finish in time."""


class SubprocessManager(object):
    """Start subprocesses and keep track of them by process ID.

    The object can be used as an async context manager; leaving the
    ``async with`` block awaits `cleanup`.
    """

    def __init__(self):
        # Maps the PID handed back by `run_command` to the CommandManager
        # that owns that process.
        self.commands: Dict[int, "CommandManager"] = {}

    async def __aenter__(self) -> "SubprocessManager":
        return self

    async def __aexit__(self, exc_type, exc_value, traceback) -> None:
        await self.cleanup()

    async def run_command(
        self,
        command: List[str],
        env: Optional[Dict[str, str]] = None,
        cwd: Optional[str] = None,
    ) -> int:
        """Start `command` asynchronously and return its process ID.

        Parameters
        ----------
        command : List[str]
            The command and its arguments.
        env : Optional[Dict[str, str]]
            Environment for the subprocess; passed through to CommandManager.
        cwd : Optional[str]
            Working directory for the subprocess; passed through to
            CommandManager.

        Returns
        -------
        int
            The PID of the started process, usable as the key for `get`.
        """
        command_obj = CommandManager(command, env, cwd)
        # NOTE(review): relies on CommandManager.run() returning the started
        # process object (its `.pid` is read below) -- confirm against run().
        process = await command_obj.run()
        self.commands[process.pid] = command_obj
        return process.pid

    def get(self, pid: int) -> Optional["CommandManager"]:
        """Return the CommandManager registered for `pid`, or None if unknown."""
        return self.commands.get(pid)

    async def cleanup(self) -> None:
        """Await `cleanup()` on every registered CommandManager, one by one."""
        # Iterate over a snapshot so a registration that happens while one of
        # the cleanups is awaited cannot invalidate the iteration.
        for manager in list(self.commands.values()):
            await manager.cleanup()


class CommandManager(object):
    """A manager for an individual subprocess."""

    def __init__(
        self,
        command: List[str],
        env: Optional[Dict[str, str]] = None,
        cwd: Optional[str] = None,
    ):
        """Record how the subprocess should be launched; nothing is started here.

        Parameters
        ----------
        command : List[str]
            The command and its arguments.
        env : Optional[Dict[str, str]]
            Environment for the subprocess. When None, a copy of the current
            process environment is used. An empty dict is kept as given.
        cwd : Optional[str]
            Working directory for the subprocess. When None, the current
            working directory at construction time is used.
        """
        self.command = command

        # Compare against None explicitly so that an intentionally empty
        # environment ({}) is not replaced by the inherited one.
        self.env = env if env is not None else os.environ.copy()
        self.cwd = cwd if cwd is not None else os.getcwd()

        self.process = None
        self.run_called: bool = False
        # Stream name -> log file path; consulted by the log-streaming methods.
        self.log_files: Dict[str, str] = {}

        # NOTE(review): signal.signal() may only be called from the main
        # thread (it raises ValueError elsewhere), and every new instance
        # replaces the SIGINT handler installed by the previous one.
        signal.signal(signal.SIGINT, self.handle_sigint)

async def __aenter__(self) -> "CommandManager":
    """Enter the ``async with`` block; the manager itself is the bound value."""
    return self

async def __aexit__(self, exc_type, exc_value, traceback):
    """Leave the ``async with`` block by awaiting `cleanup`.

    The exception arguments are not inspected, and nothing truthy is
    returned, so an exception raised inside the block keeps propagating.
    """
    await self.cleanup()

def handle_sigint(self, signum, frame):
    """React to SIGINT by announcing it and scheduling `kill` as a task.

    `signum` and `frame` are the arguments a handler installed with
    signal.signal() receives; neither is used.
    """
    print("SIGINT received.")
    # NOTE(review): asyncio.create_task() needs a running event loop in the
    # current thread and raises RuntimeError otherwise -- confirm this handler
    # can only fire while the loop is running.
    pending_kill = self.kill()
    asyncio.create_task(pending_kill)

async def wait(
    self, timeout: Optional[float] = None, stream: Optional[str] = None
) -> None:
    """Wait for the subprocess to finish, optionally bounded and optionally streaming.

    Parameters
    ----------
    timeout : Optional[float]
        Maximum number of seconds to wait. None waits without a limit.
    stream : Optional[str]
        When given, `emit_logs(stream)` is awaited instead of the bare
        process wait, so the named log stream is emitted while waiting.

    Notes
    -----
    A timeout is not raised to the caller: the pending wait is cancelled and a
    message naming the command and the timeout is printed.
    """
    # Build the awaitable once; both the bounded and unbounded paths use it.
    if stream is None:
        waiter = self.process.wait()
    else:
        waiter = self.emit_logs(stream)

    if timeout is None:
        await waiter
        return

    try:
        # wait_for cancels `waiter` itself when the timeout expires, so no
        # background task is left running after this method returns.
        await asyncio.wait_for(waiter, timeout)
    except asyncio.TimeoutError:
        command_string = " ".join(self.command)
        print(
            f"Timeout: The process: '{command_string}' didn't complete within {timeout} seconds."
        )

async def run(self):
"""Run the subprocess, streaming the logs to temporary files"""

self.temp_dir = tempfile.mkdtemp()
stdout_logfile = os.path.join(self.temp_dir, "stdout.log")
stderr_logfile = os.path.join(self.temp_dir, "stderr.log")
Expand All @@ -114,14 +132,20 @@ async def run(self):
await self.cleanup()

async def stream_logs(
self, stream, position=None, timeout_per_line=None, log_write_delay=0.01
self,
stream: str,
position: Optional[int] = None,
timeout_per_line: Optional[float] = None,
log_write_delay: float = 0.01,
):
"""Stream logs from the subprocess using the log files"""

if self.run_called is False:
raise ValueError("No command run yet to get the logs for...")

if stream not in self.log_files:
raise ValueError(
f"No log file found for {stream}, valid values are: {list(self.log_files.keys())}"
f"No log file found for '{stream}', valid values are: {list(self.log_files.keys())}"
)

log_file = self.log_files[stream]
Expand Down Expand Up @@ -161,15 +185,21 @@ async def stream_logs(
position = f.tell()
yield position, line.strip()

async def emit_logs(
    self, stream: str = "stdout", custom_logger: Callable[..., None] = print
) -> None:
    """Pass every line produced by `stream_logs(stream)` to `custom_logger`.

    Parameters
    ----------
    stream : str
        Name of the log stream to read; forwarded to `stream_logs`.
    custom_logger : Callable
        Called once per line with the line as its only argument. Defaults to
        `print`.
    """
    # stream_logs yields (position, line) pairs; only the line is emitted.
    async for _, line in self.stream_logs(stream):
        custom_logger(line)

async def cleanup(self):
    """Delete this manager's temporary log directory, if it has one.

    Does nothing when no `temp_dir` attribute has been set on the instance.
    Errors raised while removing the directory are ignored.
    """
    if not hasattr(self, "temp_dir"):
        return
    shutil.rmtree(self.temp_dir, ignore_errors=True)

async def kill(self, termination_timeout=5):
async def kill(self, termination_timeout: float = 5):
"""Kill the subprocess."""

if self.process is not None:
if self.process.returncode is None:
self.process.terminate()
Expand Down

0 comments on commit c681f71

Please sign in to comment.