Source code for pykern.pksubprocess

"""Wrapper for subprocess.

:copyright: Copyright (c) 2016 RadiaSoft LLC.  All Rights Reserved.
:license: http://www.apache.org/licenses/LICENSE-2.0.html
"""

from pykern.pkdebug import pkdc, pkdexc, pkdp, pkdlog
import os
import signal
import six
import subprocess
import threading

#: Signals caught while a child process runs and passed on to it
_SIGNALS = (
    signal.SIGTERM,
    signal.SIGINT,
)


def check_call_with_signals(cmd, output=None, env=None, msg=None, recursive_kill=False):
    """Run cmd, writing to output.

    stdin is `os.devnull`.

    Passes SIGTERM and SIGINT on to the child process. If `output`
    is a string, it will be opened in write ('w') mode and closed
    before returning.

    Args:
        cmd (list): passed to subprocess verbatim
        output (file or str): where to write stdout and stderr
        env (dict): environment to use
        msg (callable): optional progress logger, called as
            ``msg(fmt, *args)`` where `fmt` uses ``{}`` placeholders
        recursive_kill (bool): EXPERIMENTAL: kill all process children, recursively

    Raises:
        AssertionError: if not called from the main thread
        RuntimeError: if the child exits with a non-zero status
    """
    # Needed both by the EXPERIMENTAL poll loop in wait_pid and by the
    # terminate grace period in the cleanup below. Importing it only inside
    # wait_pid left the cleanup's time.sleep raising a (swallowed) NameError.
    import time

    if not _is_main_thread():
        # Explicit raise rather than ``assert`` so the check is not stripped
        # by ``python -O``; signal handlers can only be installed from the
        # main thread.
        raise AssertionError(
            "subprocesses which require signals need to be started in main thread"
        )
    p = None
    prev_signal = {sig: signal.getsignal(sig) for sig in _SIGNALS}
    pid = None
    all_pids = set()
    stdout = output
    devnull = None

    def signal_handler(sig, frame):
        """Forward `sig` to the child, then chain to the previous handler"""
        try:
            if p is not None:
                p.send_signal(sig)
        except Exception:
            # Nothing we can do, still want to cascade
            pass
        ps = prev_signal[sig]
        if ps in (None, signal.SIG_IGN, signal.SIG_DFL):
            return
        ps(sig, frame)

    def wait_pid():
        """Iteratively and recursively gather all children

        mpiexec sets a session, and doesn't cascade signals so
        processes can stay running after an exit.

        Returns:
            int: exit status in `subprocess.Popen.returncode` form
                (negative signal number if the child was killed by a signal)
        """
        # always SIGKILL the process we started
        all_pids.add(pid)
        if not recursive_kill:
            # simple process running
            return p.wait()
        import psutil

        # EXPERIMENTAL
        z = psutil.Process(pid)
        t = 0.1
        while True:
            all_pids.update(c.pid for c in z.children(recursive=True))
            x, s = os.waitpid(pid, os.WNOHANG)
            if x != 0:
                break
            time.sleep(t)
            # first sleep is very fast, just in case a fast
            # process starts. Then polling less frequently
            # helps avoid thrashing, especially with mpi.
            t = 0.5
        # os.waitpid yields a raw wait status (exit code << 8, or signal
        # bits); convert to the form p.wait() returns so both paths report
        # the same value. Raw status 0 still maps to 0.
        if os.WIFSIGNALED(s):
            return -os.WTERMSIG(s)
        return os.WEXITSTATUS(s)

    try:
        if isinstance(output, six.string_types):
            stdout = open(output, "w")
        stderr = subprocess.STDOUT if stdout else None
        for sig in _SIGNALS:
            signal.signal(sig, signal_handler)
        # Keep a reference so the handle can be closed in the finally below
        devnull = open(os.devnull)
        p = subprocess.Popen(
            cmd,
            stdin=devnull,
            stdout=stdout,
            stderr=stderr,
            env=env,
        )
        pid = p.pid
        if msg:
            msg("{}: started: {}", pid, cmd)
        s = wait_pid()
        # child has been reaped; nothing left to terminate
        p = None
        if s != 0:
            raise RuntimeError("error exit({})".format(s))
        if msg:
            msg("{}: normal exit(0): {}", pid, cmd)
    except Exception:
        if msg:
            msg("{}: exception: {} {}", pid, cmd, pkdexc())
        raise
    finally:
        for sig in _SIGNALS:
            signal.signal(sig, prev_signal[sig])
        if p is not None:
            # still running (e.g. wait was interrupted): ask politely first
            if msg:
                msg("{}: terminating: {}", pid, cmd)
            try:
                p.terminate()
                # grace period before the SIGKILLs below
                time.sleep(0.1)
            except Exception:
                pass
        for x in all_pids:
            # NOTE(review): these pids may already have been reaped, so a
            # recycled pid could in principle be signalled; errors ignored.
            try:
                os.kill(x, signal.SIGKILL)
            except Exception:
                pass
            try:
                # maybe we didn't catch all the children so try this.
                # Attempted even if os.kill failed, e.g. because the session
                # leader (mpiexec) already exited but its group lives on.
                os.killpg(x, signal.SIGKILL)
            except Exception:
                pass
        if devnull is not None:
            devnull.close()
        if stdout is not output:
            stdout.close()
def _is_main_thread():
    """Report whether the caller is running in the interpreter's main thread.

    Signal handlers may only be installed from the main thread, so this
    guards code that sets them.

    Returns:
        bool: True when the current thread is the main thread
    """
    current = threading.current_thread()
    main_thread = getattr(threading, "main_thread", None)
    if main_thread is None:
        # Python 2 has no threading.main_thread: See http://stackoverflow.com/a/23207116
        return current.__class__ == threading._MainThread
    # Python 3
    return current == main_thread()