Source code for pharmpy.workflows.workflow

from __future__ import annotations

import inspect
import uuid
from typing import Generic, Optional, TypeVar, Union

from pharmpy.deps import networkx as nx
from pharmpy.internals.immutable import Immutable

from .task import Task

T = TypeVar('T')


class WorkflowBase:
    """Common base class for Workflow and WorkflowBuilder"""

    _g: nx.DiGraph

    @property
    def tasks(self) -> list[Task]:
        """All tasks in workflow"""
        return list(self._g.nodes())

    @property
    def input_tasks(self) -> list[Task]:
        """All input (source) tasks of entire workflow"""
        return [node for node in self._g.nodes if self._g.in_degree(node) == 0]

    @property
    def output_tasks(self) -> list[Task]:
        """All output tasks (sink) tasks of entire workflow"""
        return [node for node in self._g.nodes if self._g.out_degree(node) == 0]

    def get_upstream_tasks(self, task: Task) -> list[Task]:
        """Get all tasks upstream of a certain task

        Parameters
        ----------
        task : Task
            Downstream task

        Returns
        -------
        list
            Upstream tasks
        """
        edges = nx.edge_dfs(self._g, task, orientation='reverse')
        return [node for node, _, _ in edges]

    def get_predecessors(self, task: Task) -> list[Task]:
        """Get all predecessors of task

        Parameters
        ----------
        task : Task
            Downstream task

        Returns
        -------
        list
            Predecessors of task
        """
        return list(self._g.predecessors(task))

    def get_successors(self, task: Task) -> list[Task]:
        """Get all successors of task

        Parameters
        ----------
        task : Task
            Upstream task

        Returns
        -------
        list
            Successors of task
        """
        return list(self._g.successors(task))

    def __len__(self):
        return len(self._g.nodes)


[docs] class WorkflowBuilder(WorkflowBase): """Builder class for Workflow""" def __init__(self, workflow=None, tasks=None, name=None): if workflow: self._g = workflow._g.copy() self.name = workflow.name else: self._g = nx.DiGraph() self.name = name if tasks: for task in tasks: self.add_task(task)
[docs] def add_task(self, task: Task, predecessors: Optional[Union[Task, list[Task]]] = None): """Add a task to the workflow Predecessors will be connected if given. Parameters ---------- task : Task Task to add predecessors : list or Task One or multiple predecessor tasks to connect to the added task """ self._g.add_node(task) if predecessors is not None: if not isinstance(predecessors, list): self._g.add_edge(predecessors, task) else: for pred in predecessors: self._g.add_edge(pred, task)
[docs] def replace_task(self, task: Task, new_task: Task): """Replace a task with a new task Parameters ---------- task : Task Task to replace new_task : Task New task """ mapping = {task: new_task} nx.relabel_nodes(self._g, mapping, copy=False)
[docs] def insert_workflow( self, other: Workflow, predecessors: Optional[Union[Task, list[Task]]] = None ): """Insert other workflow Parameters ---------- other : Workflow Workflow to insert predecessors : list or Task One or multiple predecessor tasks to connect to the inputs of the inserted workflow. If None all output tasks will be found and used as predecessors. """ if predecessors is None: output_tasks = self.output_tasks else: if isinstance(predecessors, list): output_tasks = predecessors else: output_tasks = [predecessors] input_tasks = other.input_tasks self._g = nx.compose(self._g, other._g) if len(input_tasks) == len(output_tasks): for inp, outp in zip(input_tasks, output_tasks): self._g.add_edge(outp, inp) elif len(input_tasks) == 1: for outp in output_tasks: self._g.add_edge(outp, input_tasks[0]) elif len(output_tasks) == 1: for inp in input_tasks: self._g.add_edge(output_tasks[0], inp) else: raise ValueError('Having N:M connections between workflows is currently not supported')
def __add__(self, other: Workflow): wb_new = WorkflowBuilder() wb_new._g = nx.compose(self._g, other._g) wb_new.name = self.name return wb_new
[docs] class Workflow(WorkflowBase, Generic[T], Immutable): """Workflow class Representation of a directed acyclic graph with Tasks as nodes and the flow of parameters as edges. Parameters ---------- tasks : list List of tasks for initialization """ def __init__(self, builder=None, graph=None, name: Optional[str] = None): if builder is not None: self._g = nx.freeze(builder._g.copy()) self._name = builder.name elif graph is not None: self._g = graph self._name = name
[docs] @classmethod def create(cls, builder=None, graph=None, name=None): return cls(builder=builder, graph=graph, name=name)
[docs] def replace(self, **kwargs): builder = kwargs.get('builder', None) if builder is None: g = kwargs.get('graph', self._g) name = kwargs.get('name', self._name) else: g = None name = None return Workflow.create(builder=builder, graph=g, name=name)
[docs] def as_dask_dict(self): """Create a dask workflow dictionary Returns ------- dict Dask workflow dictionary """ ids = {} for task in self._g.nodes(): assert isinstance(task, Task) ids[task] = f'{task.name}-{uuid.uuid4()}' if len(self.output_tasks) == 1: ids[self.output_tasks[0]] = 'results' else: raise ValueError("Workflow can only have one output task") as_dict = {} for task in nx.dfs_tree(self._g): key = ids[task] input_list = list(task.task_input) input_list.extend(ids[t] for t in self._g.predecessors(task)) value = (task.function, *input_list) as_dict[key] = value return as_dict
@property def name(self): """Name of Workflow""" return self._name
[docs] def plot_dask(self, filename: str): """Save a visualization of workflow to file Parameters ---------- filename : str Path """ from dask import visualize # pyright: ignore [reportPrivateImportUsage] visualize(self.as_dask_dict(), filename=filename, collapse_outputs=True)
def __str__(self): return '\n'.join([str(task) for task in self._g]) def __add__(self, other: Workflow): g = nx.compose(self._g, other._g) wf_new = Workflow(graph=g) return wf_new
def insert_context(wb: WorkflowBuilder, context): """Insert context for all tasks in a workflow needing it having context as first argument of function """ for task in wb.tasks: parameters = tuple(inspect.signature(task.function).parameters) if parameters and parameters[0] == 'context': new_task = task.replace(task_input=(context, *task.task_input)) wb.replace_task(task, new_task)