Source code for pharmpy.tools.mfl.parse

import warnings
from collections import defaultdict
from itertools import product
from typing import List, Optional

from lark import Lark

from pharmpy.model import Model
from pharmpy.modeling.covariate_effect import get_covariate_effects
from pharmpy.modeling.odes import (
    get_number_of_peripheral_compartments,
    get_number_of_transit_compartments,
    has_first_order_absorption,
    has_first_order_elimination,
    has_instantaneous_absorption,
    has_lag_time,
    has_michaelis_menten_elimination,
    has_mixed_mm_fo_elimination,
    has_seq_zo_fo_absorption,
    has_zero_order_absorption,
    has_zero_order_elimination,
)
from pharmpy.tools.mfl.feature.covariate import features as covariate_features
from pharmpy.tools.mfl.statement.definition import Let
from pharmpy.tools.mfl.statement.feature.absorption import Absorption
from pharmpy.tools.mfl.statement.feature.covariate import Covariate
from pharmpy.tools.mfl.statement.feature.direct_effect import DirectEffect
from pharmpy.tools.mfl.statement.feature.effect_comp import EffectComp
from pharmpy.tools.mfl.statement.feature.elimination import Elimination
from pharmpy.tools.mfl.statement.feature.indirect_effect import IndirectEffect
from pharmpy.tools.mfl.statement.feature.lagtime import LagTime
from pharmpy.tools.mfl.statement.feature.metabolite import Metabolite
from pharmpy.tools.mfl.statement.feature.peripherals import Peripherals
from pharmpy.tools.mfl.statement.feature.symbols import Name
from pharmpy.tools.mfl.statement.feature.transits import Transits

from .grammar import grammar
from .helpers import (
    all_funcs,
    funcs,
    modelsearch_features,
    structsearch_metabolite_features,
    structsearch_pd_features,
)
from .interpreter import MFLInterpreter
from .statement.feature.covariate import Ref
from .statement.feature.symbols import Option, Wildcard
from .statement.statement import Statement
from .stringify import stringify as mfl_stringify


def parse(code: str, mfl_class=False) -> List[Statement]:
    """Parse MFL source text.

    Parameters
    ----------
    code : str
        The MFL text to parse.
    mfl_class : bool
        When truthy, the parsed statements are wrapped into a
        ``ModelFeatures`` object instead of being returned as a plain list.

    Returns
    -------
    The list of statements produced by ``_parse``, or the ``ModelFeatures``
    object created from that list when ``mfl_class`` is truthy.
    """
    statements = _parse(code)

    # TODO : only return class once it has been implemented everywhere
    if not mfl_class:
        return statements
    return ModelFeatures.create_from_mfl_statement_list(statements)


def _parse(code: str):
    """Turn MFL text into a validated list of MFL statements.

    The text is parsed with a LALR ``Lark`` parser built from ``grammar``,
    the resulting tree is interpreted by ``MFLInterpreter`` and the statement
    list is checked by ``validate_mfl_list`` before being returned.
    """
    # NOTE: lexer='standard' is deliberately not set; it does not work because
    #       lexing for the MFL grammar is context-dependent
    lark_options = dict(
        start='start',
        parser='lalr',
        propagate_positions=False,
        maybe_placeholders=False,
        debug=False,
        cache=True,
    )
    syntax_tree = Lark(grammar, **lark_options).parse(code)

    statements = MFLInterpreter().interpret(syntax_tree)
    validate_mfl_list(statements)
    return statements


def validate_mfl_list(mfl_statement_list):
    """Check the covariate statements of an MFL statement list.

    Only ``Covariate`` statements are inspected. A ``ValueError`` is raised
    when a mandatory (non-optional) covariate statement uses a wildcard
    functional form, or when the same (parameter, covariate) pair is forced by
    more than one mandatory statement. Statements whose parameter or covariate
    is a ``Ref`` are not expanded into pairs.
    """
    # TODO : Implement for other features as necessary
    optional_cov = set()
    mandatory_cov = set()
    # FIXME (?) : Allow for same exact cov effect to be forced by multiple explicit statements
    for statement in mfl_statement_list:
        if not isinstance(statement, Covariate):
            continue

        if not statement.optional.option and isinstance(statement.fp, Wildcard):
            raise ValueError(
                f"Error in {mfl_stringify([statement])} :"
                f" Mandatory effects need to be explicit (not '*')"
            )

        # References cannot be expanded into explicit pairs here
        if isinstance(statement.parameter, Ref) or isinstance(statement.covariate, Ref):
            continue

        if statement.optional.option:
            optional_cov.update(product(statement.parameter, statement.covariate))
            continue

        error = [
            pair
            for pair in product(statement.parameter, statement.covariate)
            if pair in mandatory_cov
        ]
        if error:
            raise ValueError(
                f"Covariate effect(s) {error} is being forced by"
                f" multiple statements. Please force only once"
            )
        mandatory_cov.update(product(statement.parameter, statement.covariate))


class ModelFeatures:
    def __init__(
        self,
        absorption=None,
        elimination=None,
        transits=tuple(),  # NOTE : This is a tuple
        peripherals=tuple(),
        lagtime=None,
        covariate=tuple(),  # Note : Should always be tuple (empty meaning no covariates)
        direct_effect=None,
        effect_comp=None,
        indirect_effect=tuple(),
        metabolite=None,
    ):
        self._absorption = absorption
        self._elimination = elimination
        self._transits = transits
        self._peripherals = peripherals
        self._lagtime = lagtime
        self._covariate = covariate
        self._direct_effect = direct_effect
        self._effect_comp = effect_comp
        self._indirect_effect = indirect_effect
        self._metabolite = metabolite

    @classmethod
    def create(
        cls,
        absorption=None,
        elimination=None,
        transits=tuple(),
        peripherals=tuple(),
        lagtime=None,
        covariate=tuple(),
        direct_effect=None,
        effect_comp=None,
        indirect_effect=tuple(),
        metabolite=None,
    ):
        # TODO : Check if allowed input value
        if absorption is not None and not isinstance(absorption, Absorption):
            raise ValueError(f"Absorption : {absorption} is not suppoerted")

        if elimination is not None and not isinstance(elimination, Elimination):
            raise ValueError(f"Elimination : {elimination} is not supported")

        if not isinstance(transits, tuple):
            raise ValueError("Transits need to be given within a tuple")
        if not all(isinstance(t, Transits) for t in transits):
            raise ValueError("All given elements of transits must be of type Transits")

        if not isinstance(peripherals, tuple):
            raise ValueError("Peripherals need to be given within a tuple")
        if not all(isinstance(p, Peripherals) for p in peripherals):
            raise ValueError(f"Peripherals : {peripherals} is not supported")

        if lagtime is not None and not isinstance(lagtime, LagTime):
            raise ValueError(f"Lagtime : {lagtime} is not supported")

        if not isinstance(covariate, tuple):
            raise ValueError("Covariates need to be given within a tuple")
        if not all(isinstance(c, Covariate) for c in covariate):
            raise ValueError(f"Covariate : {covariate} is not supported")

        if direct_effect is not None and not isinstance(direct_effect, DirectEffect):
            raise ValueError(f"DirectEffect : {direct_effect} is not supported")

        if effect_comp is not None and not isinstance(effect_comp, EffectComp):
            raise ValueError(f"EffectComp : {effect_comp} is not supported")

        if not isinstance(indirect_effect, tuple):
            raise ValueError("IndirectEffect(s) need to be given within a tuple")
        if not all(isinstance(i, IndirectEffect) for i in indirect_effect):
            raise ValueError("All given elements of indirect_effect must be of type IndirectEffect")

        if metabolite is not None and not isinstance(metabolite, Metabolite):
            raise ValueError(f"Metabolite : {metabolite} is not supported")

        # Indicate that we have a PK model, need default features
        if any(x for x in [absorption, elimination, transits, peripherals, lagtime, metabolite]):
            if absorption is None:
                absorption = Absorption((Name('INST'),))
            if elimination is None:
                elimination = Elimination((Name('FO'),))
            if transits == tuple():
                transits = (Transits((0,), (Name('DEPOT'),)),)
            if peripherals == tuple():
                peripherals += (Peripherals((0,)),)
            if lagtime is None:
                lagtime = LagTime((Name('OFF'),))

        return cls(
            absorption=absorption,
            elimination=elimination,
            transits=transits,
            peripherals=peripherals,
            lagtime=lagtime,
            covariate=covariate,
            direct_effect=direct_effect,
            effect_comp=effect_comp,
            indirect_effect=indirect_effect,
            metabolite=metabolite,
        )

    @classmethod
    def create_from_mfl_statement_list(cls, mfl_list):
        absorption = None
        elimination = None
        transits = tuple()
        peripherals = tuple()
        lagtime = None
        covariate = tuple()
        direct_effect = None
        effect_comp = None
        indirect_effect = tuple()
        metabolite = None
        let = {}

        for statement in mfl_list:
            if isinstance(statement, Absorption):
                absorption = absorption + statement if absorption else statement
            elif isinstance(statement, Elimination):
                elimination = elimination + statement if elimination else statement
            elif isinstance(statement, Transits):
                transits += (statement,)
            elif isinstance(statement, Peripherals):
                peripherals += (statement,)
            elif isinstance(statement, LagTime):
                lagtime = lagtime + statement if lagtime else statement
            elif isinstance(statement, Covariate):
                covariate += (statement,)
            elif isinstance(statement, Let):
                let[statement.name] = statement.value
            elif isinstance(statement, DirectEffect):
                direct_effect = direct_effect + statement if direct_effect else statement
            elif isinstance(statement, EffectComp):
                effect_comp = effect_comp + statement if effect_comp else statement
            elif isinstance(statement, IndirectEffect):
                indirect_effect += (statement,)
            elif isinstance(statement, Metabolite):
                metabolite = metabolite + statement if metabolite else statement
            else:
                raise ValueError(f'Unknown ({type(statement)} statement ({statement}) given.')

        # Substitute all Let statements (if any)
        if len(let) != 0:
            # FIXME : Multiple let statements for the same reference value ?
            def _let_subs(cov, let):
                return Covariate(
                    parameter=(
                        cov.parameter
                        if not (isinstance(cov.parameter, Ref) and cov.parameter.name in let)
                        else let[cov.parameter.name]
                    ),
                    covariate=(
                        cov.covariate
                        if not (isinstance(cov.covariate, Ref) and cov.covariate.name in let)
                        else let[cov.covariate.name]
                    ),
                    fp=cov.fp,
                    op=cov.op,
                    optional=cov.optional,
                )

            # Add other attributes as necessary
            covariate = tuple([_let_subs(cov, let) for cov in covariate])

        mfl = cls.create(
            absorption=absorption,
            elimination=elimination,
            transits=transits,
            peripherals=peripherals,
            lagtime=lagtime,
            covariate=covariate,
            direct_effect=direct_effect,
            effect_comp=effect_comp,
            indirect_effect=indirect_effect,
            metabolite=metabolite,
        )
        return mfl

    @classmethod
    def create_from_mfl_string(cls, mfl_string):
        """Build a ``ModelFeatures`` object directly from MFL source text.

        Delegates to ``parse`` with ``mfl_class=True``.
        """
        features = parse(mfl_string, mfl_class=True)
        return features

    def replace(self, **kwargs):
        absorption = kwargs.get("absorption", self._absorption)
        elimination = kwargs.get("elimination", self._elimination)
        transits = kwargs.get("transits", self._transits)
        peripherals = kwargs.get("peripherals", self._peripherals)
        lagtime = kwargs.get("lagtime", self._lagtime)
        covariate = kwargs.get("covariate", self._covariate)
        direct_effect = kwargs.get("direct_effect", self._direct_effect)
        effect_comp = kwargs.get("effect_comp", self._effect_comp)
        indirect_effect = kwargs.get("indirect_effect", self._indirect_effect)
        metabolite = kwargs.get("metabolite", self._metabolite)
        return ModelFeatures.create(
            absorption=absorption,
            elimination=elimination,
            transits=transits,
            peripherals=peripherals,
            lagtime=lagtime,
            covariate=covariate,
            direct_effect=direct_effect,
            effect_comp=effect_comp,
            indirect_effect=indirect_effect,
            metabolite=metabolite,
        )

    def replace_features(self, mfl_str):
        """Return a new ``ModelFeatures`` with the features in ``mfl_str`` replaced.

        ``mfl_str`` is parsed with ``_parse``. Each resulting statement is
        mapped to its feature name; statements of tuple-valued features are
        wrapped in a one-element tuple. Several statements for the same
        feature are accumulated with ``+=`` before being handed to
        ``replace``.
        """
        attribute_of = {
            Absorption: 'absorption',
            Elimination: 'elimination',
            Transits: 'transits',
            Peripherals: 'peripherals',
            LagTime: 'lagtime',
            Covariate: 'covariate',
            DirectEffect: 'direct_effect',
            EffectComp: 'effect_comp',
            IndirectEffect: 'indirect_effect',
            Metabolite: 'metabolite',
        }
        tuple_valued = (Transits, Peripherals, Covariate, IndirectEffect)

        replacements = {}
        for statement in _parse(mfl_str):
            statement_type = type(statement)
            attribute = attribute_of[statement_type]
            value = (statement,) if statement_type in tuple_valued else statement
            if attribute in replacements:
                replacements[attribute] += value
            else:
                replacements[attribute] = value
        return self.replace(**replacements)

    @property
    def absorption(self):
        """The stored absorption feature (``None`` by default)."""
        return self._absorption

    @property
    def elimination(self):
        """The stored elimination feature (``None`` by default)."""
        return self._elimination

    @property
    def transits(self):
        """The stored transits features (an empty tuple by default)."""
        return self._transits

    @property
    def peripherals(self):
        """The stored peripherals features (an empty tuple by default)."""
        return self._peripherals

    @property
    def lagtime(self):
        """The stored lag time feature (``None`` by default)."""
        return self._lagtime

    @property
    def covariate(self):
        """The stored covariate features (an empty tuple by default)."""
        return self._covariate

    @property
    def direct_effect(self):
        """The stored direct effect feature (``None`` by default)."""
        return self._direct_effect

    @property
    def effect_comp(self):
        """The stored effect compartment feature (``None`` by default)."""
        return self._effect_comp

    @property
    def indirect_effect(self):
        """The stored indirect effect features (an empty tuple by default)."""
        return self._indirect_effect

    @property
    def metabolite(self):
        """The stored metabolite feature (``None`` by default)."""
        return self._metabolite

    def expand(self, model):
        """Return a new ``ModelFeatures`` built from the evaluated features.

        Covariate statements are evaluated against ``model``; every other
        feature is replaced by its ``eval`` attribute. The evaluated covariate
        statements are then broken down into individual combinations,
        optional combinations that are also forced are dropped, and the
        remainder is passed through ``_reduce_covariate`` before new
        ``Covariate`` statements are created.

        Raises
        ------
        ValueError
            If, after evaluation, the same (parameter, covariate) pair is
            found more than once among the non-optional covariate statements.
        """
        # All (parameter, covariate) pairs of statements that do not use a
        # reference on either side
        explicit_covariates = set(
            [
                p
                for c in self.covariate
                if (not isinstance(c.parameter, Ref) and not isinstance(c.covariate, Ref))
                for p in product(c.parameter, c.covariate)
            ]
        )  # Override @ reference with explicit value
        # Evaluate each statement; statements evaluating to None are dropped
        covariate = tuple(
            c for c in [c.eval(model, explicit_covariates) for c in self.covariate] if c is not None
        )

        # (parameter, covariate) pairs of the non-optional (forced) statements
        param_cov = [
            p for c in covariate for p in product(c.parameter, c.covariate) if not c.optional.option
        ]
        counts = [(c, param_cov.count(c)) for c in param_cov]
        if any(c[1] > 1 for c in counts):
            error = set(c[0] for c in filter(lambda c: c[1] > 1, counts))
            raise ValueError(
                f"Covariate effect(s) {error} is forced by multiple reference statements."
                f" Please redefine the search space."
            )

        # Overwrite optional covariates if forced
        if covariate:
            # Every (parameter, covariate, fp, op, optional) combination
            covariate_combinations = set()
            for cov in covariate:
                covariate_combinations.update(
                    set(
                        product(
                            cov.parameter, cov.covariate, cov.eval().fp, (cov.op,), (cov.optional,)
                        )
                    )
                )
            # Iterate over a copy since the set is modified in the loop
            for combination in covariate_combinations.copy():
                if not combination[4].option:
                    # Same combination but flagged as optional
                    opposite = combination[:-1] + (Option(True),)
                    covariate_combinations.discard(opposite)

            # Wrap every element in a singleton set before reducing
            # NOTE(review): _reduce_covariate is defined outside this view;
            # presumably it merges compatible combinations - confirm
            covariate_combinations = [tuple({x} for x in e) for e in covariate_combinations]
            covariate_combinations = _reduce_covariate(covariate_combinations)
            covariate = tuple()
            for param, cov, fp, op, opt in covariate_combinations:
                # op and opt are taken as the first element of their sets
                covariate += (
                    Covariate(tuple(param), tuple(cov), tuple(fp), list(op)[0], list(opt)[0]),
                )

        return ModelFeatures.create(
            absorption=self.absorption.eval if self.absorption else None,
            elimination=self.elimination.eval if self.elimination else None,
            transits=tuple([t.eval for t in self.transits]),
            peripherals=tuple([p.eval for p in self.peripherals]),
            lagtime=self.lagtime.eval if self.lagtime else None,
            covariate=covariate,
            direct_effect=self.direct_effect.eval if self.direct_effect else None,
            effect_comp=self.effect_comp.eval if self.effect_comp else None,
            indirect_effect=tuple([i.eval for i in self.indirect_effect]),
            metabolite=self.metabolite.eval if self.metabolite else None,
        )

    def mfl_statement_list(self, attribute_type: Optional[List[str]] = []):
        """Add the repspective MFL attributes to a list"""

        # NOTE : This function is needed to be able to convert the classes to functions

        if not attribute_type:
            attribute_type = [
                "absorption",
                "elimination",
                "transits",
                "peripherals",
                "lagtime",
                "covariate",
                "direct_effect",
                "effect_comp",
                "indirect_effect",
                "metabolite",
            ]
        mfl_list = []
        if "absorption" in attribute_type:
            mfl_list.append(self.absorption)
        if "elimination" in attribute_type:
            mfl_list.append(self.elimination)
        if "transits" in attribute_type:
            for t in self.transits:
                mfl_list.append(t)
        if "peripherals" in attribute_type:
            for p in self.peripherals:
                mfl_list.append(p)
        if "lagtime" in attribute_type:
            mfl_list.append(self.lagtime)
        if "covariate" in attribute_type:
            for c in self.covariate:
                mfl_list.append(c)
        if "direct_effect" in attribute_type:
            mfl_list.append(self.direct_effect)
        if "effect_comp" in attribute_type:
            mfl_list.append(self.effect_comp)
        if "indirect_effect" in attribute_type:
            for i in self.indirect_effect:
                mfl_list.append(i)
        if "metabolite" in attribute_type:
            mfl_list.append(self.metabolite)

        return [m for m in mfl_list if m is not None]

    def filter(self, subset):
        if subset == "pk":
            peripherals = self._extract_peripherals()
            if peripherals["DRUG"]:
                peripherals = (Peripherals(tuple(peripherals["DRUG"]), (Name("DRUG"),)),)
            else:
                peripherals = tuple()
            return ModelFeatures.create(
                absorption=self.absorption,
                elimination=self.elimination,
                transits=self.transits,
                peripherals=peripherals,
                lagtime=self.lagtime,
            )
        elif subset == "pd":
            return ModelFeatures.create(
                direct_effect=self.direct_effect,
                effect_comp=self.effect_comp,
                indirect_effect=self.indirect_effect,
            )
        elif subset == "metabolite":
            peripherals = self._extract_peripherals()
            if peripherals["MET"]:
                peripherals = (Peripherals(tuple(peripherals["MET"]), (Name("MET"),)),)
            else:
                peripherals = tuple()
            return ModelFeatures.create(peripherals=peripherals, metabolite=self.metabolite)
        else:
            raise ValueError(f"Unknown subset {subset}")

    def convert_to_funcs(
        self,
        attribute_type: Optional[List[str]] = None,
        model: Optional[Model] = None,
        subset_features=None,
    ):
        """Convert the stored statements into functions.

        When ``subset_features`` is ``"pk"``, ``"pd"`` or ``"metabolite"`` the
        object is first filtered to that subset and ``funcs`` is called with
        the feature table belonging to it. Otherwise ``all_funcs`` is called
        on the unfiltered statements, using a fresh ``Model()`` when no
        (truthy) model was given.
        """
        subset_tables = (
            ("pk", modelsearch_features),
            ("pd", structsearch_pd_features),
            ("metabolite", structsearch_metabolite_features),
        )
        for subset_name, feature_table in subset_tables:
            if subset_features == subset_name:
                filtered_mfl = self.filter(subset_features)
                statements = filtered_mfl.mfl_statement_list(attribute_type)
                return funcs(model, statements, feature_table)

        # The model argument is used for when extracting covariates.
        effective_model = model if model else Model()
        return all_funcs(effective_model, self.mfl_statement_list(attribute_type))

    def contain_subset(self, mfl, model: Optional[Model] = None, tool: Optional[str] = None):
        """See if class contain specified subset"""
        transits = self._subset_transits(mfl)
        peripheral_lhs = self._extract_peripherals()
        peripheral_rhs = mfl._extract_peripherals()

        if (
            all([s in self.absorption.eval.modes for s in mfl.absorption.eval.modes])
            and all([s in self.elimination.eval.modes for s in mfl.elimination.eval.modes])
            and transits
            and all([s in self.lagtime.eval.modes for s in mfl.lagtime.eval.modes])
        ):
            if tool is None or tool in ["modelsearch"]:
                return all(p in peripheral_lhs["DRUG"] for p in list(peripheral_rhs["DRUG"]))
            else:
                if not (
                    all(p in peripheral_lhs["DRUG"] for p in list(peripheral_rhs["DRUG"]))
                    and all(p in peripheral_lhs["MET"] for p in list(peripheral_rhs["MET"]))
                ):
                    return False
                if self.covariate != tuple() or mfl.covariate != tuple():
                    if model is None:
                        warnings.warn("Need argument 'model' in order to compare covariates")
                    else:
                        return True if self._subset_covariate(mfl, model) else False
        else:
            return False

    def _subset_transits(self, mfl):
        lhs_counts = set([c for t in self.transits for c in t.counts])
        lhs_depot = set([d for t in self.transits for d in t.eval.depot])

        rhs_counts = set([c for t in mfl.transits for c in t.counts])
        rhs_depot = set([d for t in mfl.transits for d in t.eval.depot])
        # FIXME : Need to compare counts per depot individually when comparing two
        # search spaces (Currenty working for model vs search space)
        return all([c in lhs_counts for c in rhs_counts]) and all(
            [d in lhs_depot for d in rhs_depot]
        )

    def _subset_covariates(self, mfl, model):
        lhs = defaultdict(list)
        rhs = defaultdict(list)

        for cov in self.covariate:
            cov_eval = cov.eval(model)
            for effect in cov_eval.fp:
                for op in cov_eval.op:
                    lhs[(effect, op)].append(product(cov_eval.parameter, cov_eval.covariate))
        for cov in mfl.covariate:
            cov_eval = cov.eval(model)
            for effect in cov_eval.fp:
                for op in cov_eval.op:
                    rhs[(effect, op)].append(product(cov_eval.parameter, cov_eval.covariate))

        for key in rhs.keys():
            if key not in lhs.keys():
                return False
            if all(p in lhs[key] for p in rhs[key]):
                continue
            else:
                return False
        return True

    def least_number_of_transformations(
        self, other, model: Optional[Model] = None, tool: Optional[str] = None
    ):
        """The smallest set of transformations to become part of other

        Returns a dict of transformations, keyed as in the dicts returned by
        ``other.convert_to_funcs``. With ``tool`` ``None`` or
        ``"modelsearch"`` the absorption, elimination, transits, ``pk``
        peripherals and lagtime are compared. With ``tool`` ``None`` the
        covariates (only when ``model`` is given), ``metabolite`` peripherals,
        direct effect, effect compartment, indirect effect and metabolite are
        compared as well.
        """

        def _compare_modes(attribute, result):
            # Add one transformation from ``other`` if no mode is shared
            mine = getattr(self, attribute)
            theirs = getattr(other, attribute)
            if (mine is None) != (theirs is None):
                raise ValueError(
                    f"{attribute} : is only part of one of the MFLs"
                    " and therefore cannot be compared"
                )
            if mine is None:
                return result
            if not any(mode in theirs.eval.modes for mode in mine.eval.modes):
                key, func = list(other.convert_to_funcs([attribute]).items())[0]
                result[key] = func
            return result

        # Add more tools than "modelsearch" if support is needed
        lnt = {}
        if tool is None or tool in ["modelsearch"]:
            lnt = _compare_modes("absorption", lnt)
            lnt = _compare_modes("elimination", lnt)
            lnt = self._lnt_transits(other, lnt)
            lnt = self._lnt_peripherals(other, lnt, "pk")
            lnt = _compare_modes("lagtime", lnt)

        # TODO : Use in covsearch instead of taking diff
        if tool is None:
            if model is not None:
                lnt = self._lnt_covariates(other, lnt, model)
            elif self.covariate != tuple() or other.covariate != tuple():
                warnings.warn("Need argument 'model' in order to compare covariates")

            lnt = self._lnt_peripherals(other, lnt, "metabolite")
            lnt = _compare_modes("direct_effect", lnt)
            lnt = _compare_modes("effect_comp", lnt)
            lnt = self._lnt_indirect_effect(other, lnt)
            lnt = _compare_modes("metabolite", lnt)

        return lnt

    def _lnt_indirect_effect(self, other, lnt):
        """Add one indirect-effect transformation to ``lnt`` when none is shared.

        Nothing is added when ``_add_helper`` reports a shared part or when
        ``other`` has nothing of its own. Otherwise the first key of this
        object that also exists for ``other`` is used, falling back on the
        first key of ``other``.
        """
        lhs, rhs, combine = _add_helper(
            self.indirect_effect, other.indirect_effect, "modes", "production"
        )
        if combine or not rhs:
            return lnt

        # No shared attribute
        func_dict = other.convert_to_funcs(["indirect_effect"])
        for chosen in lhs.keys():
            if chosen in rhs.keys():
                break
        else:
            # No key is matching
            chosen = next(iter(rhs))

        entry = ('INDIRECT', rhs[chosen][0], chosen.name)
        lnt[entry] = func_dict[entry]
        return lnt

    def _lnt_transits(self, other, lnt):
        """Add one transits transformation to ``lnt`` when none is shared.

        Nothing is added when ``_add_helper`` reports a shared part or when
        ``other`` has nothing of its own. Otherwise the first key of this
        object that also exists for ``other`` is used, falling back on the
        first key of ``other``.
        """
        lhs, rhs, combine = _add_helper(self.transits, other.transits, "counts", "depot")
        if combine or not rhs:
            return lnt

        # No shared attribute
        func_dict = other.convert_to_funcs(["transits"])
        for chosen in lhs.keys():
            if chosen in rhs.keys():
                break
        else:
            # No key is matching
            chosen = next(iter(rhs))

        entry = ('TRANSITS', rhs[chosen][0], chosen.name)
        lnt[entry] = func_dict[entry]
        return lnt

    def _lnt_peripherals(self, other, lnt, subset):
        if subset == "pk":
            keys = ["DRUG"]
        if subset == "metabolite":
            keys = ["MET"]
        else:
            keys = ["DRUG", "MET"]
        lhs = self._extract_peripherals()
        rhs = other._extract_peripherals()
        func_dict = other.convert_to_funcs(["peripherals"])
        for key in keys:
            if not any(c in rhs[key] for c in lhs[key]):
                if key == "DRUG":
                    if rhs[key]:
                        lnt[("PERIPHERALS", min(rhs[key]))] = func_dict[
                            ("PERIPHERALS", min(rhs[key]))
                        ]
                elif key == "MET":
                    if rhs[key]:
                        lnt[("PERIPHERALS", min(rhs[key]), "METABOLITE")] = func_dict[
                            ("PERIPHERALS", min(rhs[key]), "METABOLITE")
                        ]
        return lnt

    def _lnt_covariates(self, other, lnt, model):
        """Add covariate transformations for the forced effects that differ.

        Only non-optional (forced) covariate combinations are considered.
        Combinations found only in this object give removal transformations,
        those found only in ``other`` give addition transformations. Both are
        merged into ``lnt``, which is returned.
        """

        def _forced_combinations(mfl):
            # Check only FORCED
            return [c for c in mfl._extract_covariates() if not c[4].option]

        def _features_for(combinations, remove):
            # One Covariate statement per combination, converted to functions
            statements = [
                Covariate((param,), (cov,), (fp,), op, opt)
                for param, cov, fp, op, opt in combinations
            ]
            mfl = ModelFeatures.create_from_mfl_statement_list(statements)
            return dict(covariate_features(model, mfl.covariate, remove=remove))

        lhs = _forced_combinations(self)
        rhs = _forced_combinations(other)

        # Remove all unique to LHS
        only_lhs = [c for c in lhs if c not in rhs]
        if only_lhs:
            lnt.update(_features_for(only_lhs, True))

        # Add all unique to RHS
        only_rhs = [c for c in rhs if c not in lhs]
        if only_rhs:
            lnt.update(_features_for(only_rhs, False))

        return lnt

    def __repr__(self):
        """Return the MFL text of all stored statements."""
        # TODO : Remove default values
        statements = self.mfl_statement_list()
        return mfl_stringify(statements)

    def __sub__(self, other):
        """Return a new ``ModelFeatures`` holding this object minus ``other``.

        Single-statement features become ``None`` when equal on both sides,
        are subtracted with ``-`` when both are truthy and differ, and are
        kept unchanged otherwise. Tuple-valued features are handled by the
        dedicated ``_add_sub_*`` helpers.
        """

        def _difference(mine, theirs):
            # Nothing to subtract from, or nothing to subtract
            if not mine or not theirs:
                return mine
            return None if mine == theirs else mine - theirs

        tuple_features = dict(
            transits=self._add_sub_transits(other, add=False),
            peripherals=self._add_sub_peripherals(other, add=False),
            covariate=self._add_sub_covariates(other, add=False),
            indirect_effect=self._add_sub_indirect_effect(other, add=False),
        )
        single_names = (
            "absorption",
            "elimination",
            "lagtime",
            "direct_effect",
            "effect_comp",
            "metabolite",
        )
        single_features = {
            name: _difference(getattr(self, name), getattr(other, name)) for name in single_names
        }
        return ModelFeatures.create(**tuple_features, **single_features)

    def __add__(self, other):
        """Return a new ``ModelFeatures`` combining this object with ``other``.

        Single-statement features are added with ``+`` when both sides are
        truthy; otherwise the truthy side is used (or this object's value if
        neither is). Tuple-valued features are handled by the dedicated
        ``_add_sub_*`` helpers.
        """

        def _combine(mine, theirs):
            if not mine:
                return theirs if theirs else mine
            return mine + theirs if theirs else mine

        tuple_features = dict(
            transits=self._add_sub_transits(other, add=True),
            peripherals=self._add_sub_peripherals(other, add=True),
            covariate=self._add_sub_covariates(other, add=True),
            indirect_effect=self._add_sub_indirect_effect(other, add=True),
        )
        single_names = (
            "absorption",
            "elimination",
            "lagtime",
            "direct_effect",
            "effect_comp",
            "metabolite",
        )
        single_features = {
            name: _combine(getattr(self, name), getattr(other, name)) for name in single_names
        }
        return ModelFeatures.create(**tuple_features, **single_features)

    def _add_sub_peripherals(self, other, add=True):
        """Apply logic for adding/subtracting mfl peripherals.
        Use add = False for subtraction

        The counts of both operands are grouped per mode ("MET" and "DRUG").
        Each mode is combined with a set union (add) or a set difference
        (subtract). A mode that ends up without counts produces no statement.
        """
        lhs_counts = self._extract_peripherals()
        rhs_counts = other._extract_peripherals()

        statements = []
        for mode in ("MET", "DRUG"):
            if add:
                counts = tuple(lhs_counts[mode].union(rhs_counts[mode]))
            else:
                counts = tuple(lhs_counts[mode].difference(rhs_counts[mode]))
            if counts:  # Skip modes left without any count
                statements.append(Peripherals(counts, (Name(mode),)))
        return tuple(statements)

    def _add_sub_indirect_effect(self, other, add=True):
        """Apply logic for adding/subtracting mfl IndirectEffect(s).
        Use add = False for subtraction

        ``_add_helper`` splits the modes, grouped per production value, into
        those only in ``self``, those only in ``other`` and those in both.
        Addition returns statements for all three groups, subtraction only
        for the group that is exclusive to ``self``.
        """
        # TODO : combine with _add_sub_transits()
        only_lhs, only_rhs, shared = _add_helper(
            self.indirect_effect, other.indirect_effect, "modes", "production"
        )

        def as_statements(modes_by_production):
            return [
                IndirectEffect(modes, (production,))
                for production, modes in modes_by_production.items()
            ]

        lhs_statements = as_statements(only_lhs)
        rhs_statements = as_statements(only_rhs)
        shared_statements = as_statements(shared)

        # TODO : Cleanup and combine all possible statements
        if not add:
            return tuple(lhs_statements)
        return tuple(lhs_statements + rhs_statements + shared_statements)

    def _add_sub_transits(self, other, add=True):
        """Apply logic for adding/subtracting mfl transits.
        Use add = False for subtraction

        ``_add_helper`` splits the counts, grouped per depot value, into those
        only in ``self``, those only in ``other`` and those in both. Addition
        returns statements for all three groups, subtraction only for the
        group that is exclusive to ``self``.
        """
        only_lhs, only_rhs, shared = _add_helper(self.transits, other.transits, "counts", "depot")

        def as_statements(counts_by_depot):
            return [Transits(counts, (depot,)) for depot, counts in counts_by_depot.items()]

        lhs_statements = as_statements(only_lhs)
        rhs_statements = as_statements(only_rhs)
        shared_statements = as_statements(shared)

        # TODO : Cleanup and combine all possible statements
        if not add:
            return tuple(lhs_statements)
        return tuple(lhs_statements + rhs_statements + shared_statements)

    def _add_sub_covariates(self, other, add=True):
        """Apply logic for adding/subtracting mfl covariates.
        Use add = False for subtraction

        Both operands are expanded into sets of 5-tuples by
        ``_extract_covariates``; index 4 of each tuple is an ``Option``.

        * add: union of both sets; whenever an effect with a true option is
          present, the same effect with ``Option(False)`` is dropped.
        * subtract: difference of the sets; in addition, an effect of
          ``other`` also removes the effect of ``self`` that differs only by
          the inverted option.

        The surviving tuples are merged by ``_reduce_covariate`` and converted
        back to ``Covariate`` statements. An empty tuple is returned when
        nothing is left.
        """
        lhs = self._extract_covariates()
        rhs = other._extract_covariates()

        def with_option(effect, flag):
            # Same effect, with only the Option at index 4 replaced.
            return effect[:4] + (Option(flag),) + effect[5:]

        if lhs and rhs:
            # Find the unique products in both sets with matching expression/operator
            effects = lhs.copy()
            if add:
                effects.update(rhs)
                # Iterate over a snapshot because the set is modified in the loop.
                for effect in effects.copy():
                    if effect[4].option:
                        effects.discard(with_option(effect, False))
            else:
                effects.difference_update(rhs)
                # The effects of rhs themselves are already gone at this point;
                # what is left to remove are their inverted-option twins.
                for effect in rhs:
                    effects.discard(with_option(effect, not effect[4].option))
        elif rhs and add:
            # lhs is empty here: adding yields rhs as is.
            effects = rhs
        elif lhs:
            # rhs is empty here: nothing to add or to remove.
            effects = lhs
        else:
            return tuple()

        # Convert all elements to SETS before using reduce
        effects_as_sets = [tuple({item} for item in effect) for effect in effects]
        cov_res = [
            Covariate(tuple(param), tuple(cov), tuple(fp), list(op)[0], list(opt)[0])
            for param, cov, fp, op, opt in _reduce_covariate(effects_as_sets)
        ]

        if any(len(c.parameter) != 0 for c in cov_res):
            return tuple(cov_res)
        return tuple()

    def __eq__(self, other):
        """Compare two sets of model features.

        Only absorption, elimination, transits, peripherals, lagtime and
        covariates take part in the comparison. Transits and covariates are
        compared through ``_eq_transits`` and ``_eq_covariate``; the other
        attributes with ``==``.
        """
        # Evaluated first, before any of the attribute comparisons below.
        transits_match = self._eq_transits(other)
        structure_matches = (
            self.absorption == other.absorption
            and self.elimination == other.elimination
            and transits_match
            and self.peripherals == other.peripherals
            and self.lagtime == other.lagtime
        )
        # The covariates are only expanded when everything else agrees.
        return structure_matches and self._eq_covariate(other)

    def _eq_transits(self, other):
        # TODO : Use add helper and check all in "combined"

        lhs_counts_depot = [
            c for t in self.transits if Name("DEPOT") in t.eval.depot for c in t.counts
        ]
        lhs_counts_nodepot = [
            c for t in self.transits if Name("NODEPOT") in t.eval.depot for c in t.counts
        ]
        rhs_counts_depot = [
            c for t in other.transits if Name("DEPOT") in t.eval.depot for c in t.counts
        ]
        rhs_counts_nodepot = [
            c for t in other.transits if Name("NODEPOT") in t.eval.depot for c in t.counts
        ]
        return set(lhs_counts_depot) == set(rhs_counts_depot) and set(lhs_counts_nodepot) == set(
            rhs_counts_nodepot
        )

    def _eq_covariate(self, other):
        """Return True if every expanded covariate effect of ``self`` is in ``other``.

        The check is one-directional: effects that only exist in ``other`` do
        not make the result False.
        """
        own_effects = self._extract_covariates()
        other_effects = other._extract_covariates()
        # Should OPTIONAL be ignored?
        return own_effects.issubset(other_effects)

    def _extract_peripherals(self):
        peripheral_dict = {"MET": set(), "DRUG": set()}
        for p in self.peripherals:
            for m in p.modes:
                peripheral_dict[m.name] = peripheral_dict[m.name].union(set(p.counts))

        return peripheral_dict

    def _extract_covariates(self):
        lhs = set()
        lhs_ref = []
        for cov in self.covariate:
            if isinstance(cov.parameter, Ref) or isinstance(cov.covariate, Ref):
                lhs_ref.append(cov)
                continue
            else:
                lhs.update(
                    set(
                        product(
                            cov.parameter, cov.covariate, cov.eval().fp, (cov.op,), (cov.optional,)
                        )
                    )
                )

        if len(lhs_ref) != 0:
            raise ValueError(
                'Cannot be performed with reference value. Try using .expand(model) first.'
            )

        return lhs

    def get_number_of_features(self, model=None):
        """Count the features held by this object.

        Every instance attribute that is not None contributes:

        * a non-tuple attribute adds its own ``len``;
        * the tuple stored as ``_covariate`` adds ``get_length(model)`` of
          each of its elements;
        * any other tuple adds the ``len`` of each of its elements.

        Parameters
        ----------
        model : optional
            Passed on to ``get_length`` of the covariate statements.
        """
        total = 0
        for name, value in vars(self).items():
            if value is None:
                continue
            if not isinstance(value, tuple):
                total += len(value)
            elif name == '_covariate':
                total += sum(feat.get_length(model) for feat in value)
            else:
                total += sum(len(feat) for feat in value)

        return total


def _add_helper(s1, s2, value_name, join_name):
    s1_join_name_dict = defaultdict(list)
    for s in s1:
        s = s.eval
        attribute_names = getattr(s, join_name)
        attribute_values = getattr(s, value_name)
        for a in attribute_names:
            s1_join_name_dict[a].extend(attribute_values)
    s2_join_name_dict = defaultdict(list)
    for s in s2:
        s = s.eval
        attribute_names = getattr(s, join_name)
        attribute_values = getattr(s, value_name)
        for a in attribute_names:
            s2_join_name_dict[a].extend(attribute_values)
    s1_unique = {
        k: tuple(set(s1_join_name_dict[k]) - set(s2_join_name_dict[k]))
        for k in s1_join_name_dict.keys()
    }
    s2_unique = {
        k: tuple(set(s2_join_name_dict[k]) - set(s1_join_name_dict[k]))
        for k in s2_join_name_dict.keys()
    }
    s2_unique = {k: v for k, v in s2_unique.items() if v}
    s12_joined = {
        k: tuple(set(s1_join_name_dict[k]).intersection(set(s2_join_name_dict[k])))
        for k in s1_join_name_dict.keys()
    }
    s12_joined = {k: v for k, v in s12_joined.items() if v}

    def remove_empty(d):
        return {k: v for k, v in d.items() if v}

    return (remove_empty(s1_unique), remove_empty(s2_unique), remove_empty(s12_joined))


def _reduce_covariate(c):
    """Merge covariate tuples by reducing on index 2, then 1, then 0"""
    for position in (2, 1, 0):
        c = _reduce(c, position)
    return c


def _reduce(s, n):
    """Reduce list of tuples of sets based on the n:th element in each tuple"""
    clean_s = []
    checked_keys = []
    for i in s:
        key = i[:n] + i[n + 1 :]
        if key in checked_keys:
            pass
        else:
            checked_keys.append(key)
            attr_set = set()
            for e in s:
                if key == e[:n] + e[n + 1 :]:
                    attr_set.update(e[n])
            clean_s.append(i[:n] + (attr_set,) + i[n + 1 :])
    return clean_s


def get_model_features(model: Model, supress_warnings: bool = False) -> str:
    """Create an MFL representation of an input model

    Given an input model, create a model feature language (MFL) string
    representation. Can currently extract absorption, elimination, transits,
    peripherals, lagtime and covariate effects.

    Parameters
    ----------
    model : Model
        Model to extract features from.
    supress_warnings : bool, optional
        Choose to supress warnings if absorption/elimination type cannot be
        determined. The default is False.

    Returns
    -------
    str
        A MFL string representation of the input model.
    """
    # ABSORPTION
    # The checks are tried in this order and the first match wins.
    absorption = None
    for label, is_absorption_type in (
        ("SEQ-ZO-FO", has_seq_zo_fo_absorption),
        ("ZO", has_zero_order_absorption),
        ("FO", has_first_order_absorption),
        ("INST", has_instantaneous_absorption),
    ):
        if is_absorption_type(model):
            absorption = label
            break
    if absorption is None and not supress_warnings:
        warnings.warn("Could not determine absorption of model.")

    # ELIMINATION
    # Same principle: first match wins.
    elimination = None
    for label, is_elimination_type in (
        ("MIX-FO-MM", has_mixed_mm_fo_elimination),
        ("ZO", has_zero_order_elimination),
        ("FO", has_first_order_elimination),
        ("MM", has_michaelis_menten_elimination),
    ):
        if is_elimination_type(model):
            elimination = label
            break
    if elimination is None and not supress_warnings:
        warnings.warn("Could not determine elimination of model.")

    # ABSORPTION DELAY (TRANSIT AND LAGTIME)
    # TRANSITS
    transits = get_number_of_transit_compartments(model)
    # TODO : DEPOT
    found_depot = model.statements.ode_system.find_depot(model.statements)
    depot = "DEPOT" if found_depot else "NODEPOT"

    lagtime = has_lag_time(model)

    # DISTRIBUTION (PERIPHERALS)
    peripherals = get_number_of_peripheral_compartments(model)

    # COVARIATES
    covariates = get_covariate_effects(model)

    # Assemble the statements in their fixed output order, leaving out every
    # feature that is absent from the model.
    statements = []
    if absorption:
        statements.append(f'ABSORPTION({absorption})')
    if elimination:
        statements.append(f'ELIMINATION({elimination})')
    if lagtime:
        statements.append("LAGTIME(ON)")
    if transits != 0:
        statements.append(f'TRANSITS({transits},{depot})')
    if peripherals != 0:
        statements.append(f'PERIPHERALS({peripherals})')
    if len(covariates) != 0:
        # FIXME : More extensive cleanup
        # Group the first element of each key under (second element of the
        # key, first entry of the value); only the first entry of each value
        # is looked at.
        grouped = defaultdict(list)
        for key, value in covariates.items():
            first_entry = value[0]
            grouped[(key[1], first_entry[0], first_entry[1])].append(key[0])
        joined = ';'.join(
            f'COVARIATE({members},{group[0]},{group[1]},{group[2]})'
            for group, members in grouped.items()
        )
        # Remove quotes from parameter names
        statements.append(joined.replace("'", ''))

    # TODO : Implement IIV, PKPD, METABOLITE, TMDD(?)

    return ";".join(statements)