from __future__ import annotations
import functools
import os
import itertools
import sys
import pathlib
import subprocess
import json
import contextlib
import copy
from typing import List
from typing import Optional
from typing import Dict
from typing import Set
import attr
import trio
import eliot
from voca import plugins
from voca import utils
from voca import streaming
from voca import log
def handle_unexpected_worker_bytes(message: bytes):
    """Log *message* as an ``unexpected_worker_output`` eliot message.

    Callers in this module use it for lines that fail JSON parsing.
    """
    fields = {"message_type": "unexpected_worker_output", "message": message}
    eliot.Message.log(**fields)
def worker_cli(should_log, module_names: Optional[List[str]] = None) -> List[str]:
    """Return the argv list used to launch one worker subprocess.

    The command runs ``voca`` as a module under the current interpreter,
    followed by the logging flag, the ``worker`` subcommand, and one
    ``-i <module>`` pair per entry in *module_names*. When *module_names*
    is ``None``, ``utils.plugin_module_paths()`` supplies the entries.
    """
    names = utils.plugin_module_paths() if module_names is None else module_names
    log_flag = "--log" if should_log else "--no-log"
    # Flatten ("-i", name) pairs into a single run of arguments.
    import_args = itertools.chain.from_iterable(("-i", name) for name in names)
    return [sys.executable, "-m", "voca", log_flag, "worker", *import_args]
@log.log_async_call
async def replay_child_messages(child: trio.Process) -> None:
    """Print each JSON message the child writes to its stdout.

    The child's stdout is read as newline-terminated frames. A frame that
    parses as JSON is printed on the parent's stdout; any frame that is not
    valid UTF-8 or not valid JSON is passed, as raw bytes, to
    ``handle_unexpected_worker_bytes`` instead of aborting the loop.
    """
    async for message_from_child in streaming.TerminatedFrameReceiver(
        child.stdout, b"\n"
    ):
        try:
            # Decode inside the ``try`` so undecodable bytes are reported
            # the same way as malformed JSON rather than raising.
            message_dict = json.loads(message_from_child.decode())
        except (UnicodeDecodeError, json.JSONDecodeError):
            handle_unexpected_worker_bytes(message_from_child)
        else:
            print(message_dict)
@log.log_call
def set_state(data: Dict[str, dict], state: Dict[str, dict]):
    """Handle switching between ``eager`` and ``strict`` mode.

    The first word of the first hypothesis transcript in *data* is treated
    as the command. When it is ``mode``, a deep copy of *state* with
    ``state["modes"]["strict"]`` flipped is returned; for any other command
    the result is ``None``. *state* itself is never mutated.
    """
    transcript = data["result"]["hypotheses"][0]["transcript"]
    command, _space, _args = transcript.partition(" ")
    if command != "mode":
        return None

    # Copy only when a change is actually made; this function is called for
    # every incoming command, and most of them are not mode switches.
    new_state = copy.deepcopy(state)
    new_state["modes"]["strict"] = not new_state["modes"]["strict"]
    return new_state
@log.log_async_call
async def delegate_task(
    data: Dict, worker: trio.Process, state: dict, action: eliot.Action
):
    """Send input data to worker process over std streams.

    *data* is merged with the current *state* and the serialized eliot task
    id of *action*, encoded as one JSON document, and written to the
    worker's stdin followed by a newline.
    """
    # A dict display is used instead of ``dict(**data, state=...)`` so that
    # input which already carries a "state" or "eliot_task_id" key is
    # overridden rather than raising TypeError for a duplicate keyword.
    wrapped_data = {
        **data,
        "state": state,
        "eliot_task_id": action.serialize_task_id().decode(),
    }
    payload = json.dumps(wrapped_data).encode() + b"\n"
    await worker.stdin.send_all(payload)
@attr.s
class Pool:
    """A set of already-started worker subprocesses."""

    # How many processes ``start()`` creates.
    num_workers: int = attr.ib(default=1)
    # Forwarded to ``worker_cli`` to pick the ``--log`` / ``--no-log`` flag.
    should_log: bool = attr.ib(default=True)
    # Forwarded to ``worker_cli`` as the modules to pass with ``-i``.
    module_names: List[str] = attr.ib(factory=list)
    # Processes that are currently held by the pool.
    processes: Set[trio.Process] = attr.ib(factory=set)

    def start(self) -> None:
        """Fill the pool by starting ``num_workers`` new processes."""
        for _worker_index in range(self.num_workers):
            self.add_new_process()

    def get_process(self) -> trio.Process:
        """Pop a process out of the pool and return it."""
        return self.processes.pop()

    def add_new_process(self) -> None:
        """Start a new process and add it to the pool."""
        command = worker_cli(self.should_log, self.module_names)
        # Both stdin and stdout are piped so the parent can write to and
        # read from the worker.
        process = trio.Process(
            command,
            stdin=subprocess.PIPE,
            stdout=subprocess.PIPE,
        )
        self.processes.add(process)
@log.log_async_call
async def run_worker(data: dict, state: dict, pool: Pool):
    """Run one job on a pooled worker, then put a fresh worker in the pool.

    A process is taken from *pool*, handed *data* and *state*, and its
    output is replayed until its stdout stream ends. After the process has
    exited, a replacement process is added to *pool*.
    """
    with eliot.start_action(action_type="run_with_work") as job_action:
        process = pool.get_process()
        await delegate_task(
            data=data, state=state, worker=process, action=job_action
        )
        # Relay everything the worker prints, then wait for it to exit.
        await replay_child_messages(process)
        await process.wait()
        # The worker taken above was removed from the pool; replace it.
        pool.add_new_process()
@log.log_async_call
async def process_stream(
    receiver, num_workers: int, should_log: bool, module_names: Optional[List[str]]
):
    """Handle all the commands coming in by delegating them to workers.

    Each item yielded by *receiver* is one message as bytes. Per message:

    * bytes that are not valid UTF-8 or not valid JSON are reported through
      ``handle_unexpected_worker_bytes`` and skipped;
    * JSON that is not an object with a ``"result"`` key is treated as a
      log line and printed;
    * a ``mode`` command (see ``set_state``) replaces the current state and
      is not forwarded to a worker;
    * any other command is run on a worker only when its ``final`` flag
      matches the current ``strict`` mode.
    """
    state = {"modes": {"strict": True}}
    pool = Pool(num_workers, should_log=should_log, module_names=module_names)
    pool.start()

    async for message_bytes in receiver:
        with eliot.start_action(state=state):
            try:
                # Decoding is inside the ``try`` so undecodable input is
                # reported like malformed JSON instead of ending the loop.
                message = message_bytes.decode()
                data = json.loads(message)
            except (UnicodeDecodeError, json.JSONDecodeError):
                handle_unexpected_worker_bytes(message_bytes)
                continue

            # Valid JSON need not be an object (it may be a list, string or
            # number); only objects can carry a "result".
            if not isinstance(data, dict) or "result" not in data:
                # Received a log, not a command.
                print(message)
                continue

            # Handle state changes.
            maybe_new_state = set_state(data, state)
            if maybe_new_state is not None:
                state = maybe_new_state
                continue

            # This logic could be moved into worker/plugin to allow for more
            # modes. Strict mode acts only on final results; non-strict mode
            # acts only on non-final ones.
            is_final = data["result"]["final"]
            is_strict = state["modes"]["strict"]
            if bool(is_final) != bool(is_strict):
                continue

            await run_worker(data=data, state=state, pool=pool)
@log.log_async_call
async def async_main(should_log, module_names: Optional[List[str]], num_workers: int):
    """Read newline-separated inputs on stdin, and process them."""
    # Wrap a duplicate of file descriptor 0 in a trio receive stream.
    # NOTE(review): ``trio._unix_pipes`` is a private trio module -- confirm
    # it exists in the trio version this project pins.
    stdin_fd = os.dup(0)
    stdin_stream = trio._unix_pipes.PipeReceiveStream(stdin_fd)
    frames = streaming.TerminatedFrameReceiver(stdin_stream, b"\n")
    options = {
        "num_workers": num_workers,
        "should_log": should_log,
        "module_names": module_names,
    }
    await process_stream(frames, **options)
@log.log_call
def main(should_log: bool, module_names: Optional[List[str]], num_workers: int):
    """Start the event loop."""
    # Bind the arguments up front so trio receives a zero-argument callable.
    entry_point = functools.partial(
        async_main, should_log, module_names, num_workers
    )
    trio.run(entry_point)