import numpy as np
from yt.funcs import mylog
from .parallel_analysis_interface import (
    ResultsStorage,
    _get_comm,
    communication_system,
    parallel_capable,
)
# Message templates exchanged between the root rank and the worker ranks.
# Every template is a dict carrying its kind under the "msg" key; senders
# copy a template and may attach a payload under "value" before sending.
#   task     -> "next"           root hands a task to a worker
#   result   -> "result"         worker returns a finished result
#   task_req -> "task_req"       worker asks for another task
#   end      -> "no_more_tasks"  root tells a worker to stop
messages = {
    kind: {"msg": text}
    for kind, text in (
        ("task", "next"),
        ("result", "result"),
        ("task_req", "task_req"),
        ("end", "no_more_tasks"),
    )
}
class TaskQueueNonRoot:
    """Worker-side end of the task queue.

    Instances are iterators: each step asks rank 0 of ``comm`` for the next
    task and stops once the "no_more_tasks" message arrives. Only rank 0 of
    ``subcomm`` talks to the root; the reply is then broadcast over
    ``subcomm`` so every member of the sub-group sees the same task.

    Parameters
    ----------
    tasks
        Stored on the instance as ``tasks``; not read by the methods of this
        class.
    comm
        Communicator wrapper whose ``comm`` attribute provides ``send``,
        ``recv`` and ``bcast``.
    subcomm
        Sub-group communicator providing ``rank``, ``bcast`` and ``barrier``.
    """

    def __init__(self, tasks, comm, subcomm):
        self.tasks = tasks
        self.results = {}
        self.comm = comm
        self.subcomm = subcomm

    def send_result(self, result):
        """Send ``result`` to rank 0, then synchronise the sub-group.

        Only rank 0 of the sub-group sends (tag 1); every member then waits
        on the sub-group barrier.
        """
        new_msg = messages["result"].copy()
        new_msg["value"] = result
        if self.subcomm.rank == 0:
            self.comm.comm.send(new_msg, dest=0, tag=1)
        self.subcomm.barrier()

    def __next__(self):
        """Request the next task from rank 0 and return its payload.

        Raises
        ------
        StopIteration
            When the reply is the "no_more_tasks" message.
        """
        msg = messages["task_req"].copy()
        if self.subcomm.rank == 0:
            # Requests go out on tag 1; the reply comes back on tag 2.
            self.comm.comm.send(msg, dest=0, tag=1)
            msg = self.comm.comm.recv(source=0, tag=2)
        # Share the reply with the rest of the sub-group.
        msg = self.subcomm.bcast(msg, root=0)
        if msg["msg"] == messages["end"]["msg"]:
            mylog.debug("Notified to end")
            raise StopIteration
        return msg["value"]

    # For Python 2 compatibility
    next = __next__

    def __iter__(self):
        return self

    # NOTE: ``callable`` shadows the builtin; the name is kept so existing
    # keyword-argument callers keep working.
    def run(self, callable):
        """Apply ``callable`` to every received task and report each result.

        Returns
        -------
        The value returned by :meth:`finalize` once the queue is exhausted.
        """
        for task in self:
            self.send_result(callable(task))
        return self.finalize()

    def finalize(self, vals=None):
        """Take part in the closing broadcast from rank 0 and return its value.

        ``vals`` is passed to ``bcast`` with ``root=0``.
        """
        return self.comm.comm.bcast(vals, root=0)
 
class TaskQueueRoot(TaskQueueNonRoot):
    """Root-side end of the task queue.

    Hands out the entries of ``tasks`` one at a time in response to task
    requests, stores incoming results, and stops once ``njobs`` requesters
    have been told that no tasks remain.

    Parameters
    ----------
    tasks
        Sized, integer-indexable collection of tasks (``len`` and
        ``tasks[i]`` are used).
    comm
        Communicator wrapper providing ``probe_loop`` and a ``comm``
        attribute with ``send``, ``recv`` and ``bcast``.
    njobs : int
        Number of end-of-work notifications to send before stopping.
    """

    def __init__(self, tasks, comm, njobs):
        self.njobs = njobs
        self.tasks = tasks
        self.results = {}
        # Index of the task most recently handed to each requesting rank.
        self.assignments = {}
        # Number of requesters that have been sent the "end" message.
        self._notified = 0
        # Index of the next task to hand out / count of tasks not yet handed out.
        self._current = 0
        self._remaining = len(self.tasks)
        self.comm = comm

    def run(self, func=None):
        """Serve requests until all jobs are notified, then share the results.

        ``func`` is accepted for signature compatibility with the worker-side
        ``run`` and is not used here.

        Returns
        -------
        The value returned by :meth:`finalize` for ``self.results``.
        """
        # NOTE(review): probe_loop presumably invokes the callback for each
        # tag-1 message until it raises StopIteration -- confirm against the
        # communicator implementation.
        self.comm.probe_loop(1, self.handle_assignment)
        return self.finalize(self.results)

    def insert_result(self, source_id, rstore):
        """Store ``rstore.result`` under the key ``rstore.result_id``.

        ``source_id`` is accepted but not used.
        """
        task_id = rstore.result_id
        self.results[task_id] = rstore.result

    def assign_task(self, source_id):
        """Reply to ``source_id`` with the next task or with the end message."""
        if self._remaining == 0:
            mylog.debug("Notifying %s to end", source_id)
            msg = messages["end"].copy()
            self._notified += 1
        else:
            msg = messages["task"].copy()
            task_id = self._current
            task = self.tasks[task_id]
            self.assignments[source_id] = task_id
            self._current += 1
            self._remaining -= 1
            msg["value"] = task
        # Replies always go out on tag 2.
        self.comm.comm.send(msg, dest=source_id, tag=2)

    def handle_assignment(self, status):
        """Receive one tag-1 message from ``status.source`` and act on it.

        Raises
        ------
        RuntimeError
            If the message kind is neither "result" nor "task_req".
        StopIteration
            Once at least ``njobs`` requesters have been sent the end message.
        """
        msg = self.comm.comm.recv(source=status.source, tag=1)
        if msg["msg"] == messages["result"]["msg"]:
            self.insert_result(status.source, msg["value"])
        elif msg["msg"] == messages["task_req"]["msg"]:
            self.assign_task(status.source)
        else:
            mylog.error("GOT AN UNKNOWN MESSAGE: %s", msg)
            raise RuntimeError(
                "Task queue received an unknown message: %r" % (msg,)
            )
        if self._notified >= self.njobs:
            raise StopIteration
 
def task_queue(func, tasks, njobs=0):
    """Distribute ``tasks`` across worker groups and run ``func`` on each.

    Rank 0 acts as the dispatcher; the remaining ranks are split into
    ``njobs`` groups that request tasks and send back ``func(task)``.

    Parameters
    ----------
    func
        Passed to the queue's ``run`` method; the worker side calls it once
        per received task.
    tasks
        Tasks for the root queue to hand out.
    njobs : int, optional
        Number of worker groups. A value <= 0 selects one group per non-root
        rank.

    Returns
    -------
    The value returned by the queue's ``run`` method on this rank.

    Raises
    ------
    RuntimeError
        If parallelism is unavailable, or ``njobs`` is not smaller than the
        communicator size.
    """
    comm = _get_comm(())
    if not parallel_capable:
        mylog.error("Cannot create task queue for serial process.")
        raise RuntimeError

    nprocs = comm.comm.size
    if njobs <= 0:
        # Default: every rank except the root is its own group.
        njobs = nprocs - 1
    if njobs >= nprocs:
        mylog.error(
            "You have asked for %s jobs, but only %s processors are available.",
            njobs,
            (nprocs - 1),
        )
        raise RuntimeError

    rank = comm.rank
    # Group 0 holds only the root; ranks 1..nprocs-1 are split into njobs groups.
    groups = [np.array([0]), *np.array_split(np.arange(1, nprocs), njobs)]
    for group_id, members in enumerate(groups):
        if rank in members:
            my_group = group_id
            break
    subcomm = communication_system.push_with_ids(groups[my_group].tolist())

    queue = (
        TaskQueueRoot(tasks, comm, njobs)
        if comm.comm.rank == 0
        else TaskQueueNonRoot(None, comm, subcomm)
    )
    # The sub-communicator is popped before the queue runs.
    communication_system.pop()
    return queue.run(func)
def dynamic_parallel_objects(tasks, njobs=0, storage=None, broadcast=True):
    """Generator that hands out ``tasks`` dynamically to worker groups.

    Rank 0 yields nothing: it serves task requests until every group has been
    told to stop. Each other rank yields the tasks it receives.

    Parameters
    ----------
    tasks
        Tasks for the root queue to hand out.
    njobs : int, optional
        Number of worker groups. A value <= 0 selects one group per non-root
        rank.
    storage : optional
        When given, each yielded item is ``(rstore, task)`` with a fresh
        ``ResultsStorage``; ``rstore`` is sent back to the root after the
        consumer resumes the generator, and ``storage.update`` is called with
        the collected results at the end.
    broadcast : bool, optional
        When true (and ``storage`` is given), the root's results are broadcast
        to all ranks before updating ``storage``; otherwise each rank uses its
        own queue's ``results``.

    Yields
    ------
    ``task`` when ``storage`` is None, otherwise ``(rstore, task)``.

    Raises
    ------
    RuntimeError
        If parallelism is unavailable, or ``njobs`` is not smaller than the
        communicator size.
    """
    comm = _get_comm(())
    if not parallel_capable:
        mylog.error("Cannot create task queue for serial process.")
        raise RuntimeError

    nprocs = comm.comm.size
    if njobs <= 0:
        # Default: every rank except the root is its own group.
        njobs = nprocs - 1
    if njobs >= nprocs:
        mylog.error(
            "You have asked for %s jobs, but only %s processors are available.",
            njobs,
            (nprocs - 1),
        )
        raise RuntimeError

    rank = comm.rank
    # Group 0 holds only the root; ranks 1..nprocs-1 are split into njobs groups.
    groups = [np.array([0]), *np.array_split(np.arange(1, nprocs), njobs)]
    for group_id, members in enumerate(groups):
        if rank in members:
            my_group = group_id
            break
    subcomm = communication_system.push_with_ids(groups[my_group].tolist())

    want_storage = storage is not None
    if comm.comm.rank == 0:
        my_q = TaskQueueRoot(tasks, comm, njobs)
        my_q.comm.probe_loop(1, my_q.handle_assignment)
    else:
        my_q = TaskQueueNonRoot(None, comm, subcomm)
        for task in my_q:
            if not want_storage:
                yield task
                continue
            rstore = ResultsStorage()
            yield rstore, task
            # Reached only after the consumer resumes the generator.
            my_q.send_result(rstore)

    if want_storage:
        collected = my_q.results
        if broadcast:
            collected = my_q.comm.comm.bcast(collected, root=0)
        storage.update(collected)
    communication_system.pop()