Source code for cpmpy.expressions.utils

#!/usr/bin/env python
#-*- coding:utf-8 -*-
##
## utils.py
##
"""
Internal utilities for expression handling.

    =================
    List of functions
    =================
    .. autosummary::
        :nosignatures:

        is_bool
        is_int
        is_num
        is_false_cst
        is_true_cst
        is_boolexpr
        is_pure_list
        is_any_list
        is_transition
        flatlist
        all_pairs
        argval
        argvals
        eval_comparison
        get_bounds     
"""
from __future__ import annotations  # treat annotations lazy (as string)

import cpmpy as cp
import numpy as np
import math
from collections.abc import Iterable  # for flatten
from itertools import combinations
from typing import TYPE_CHECKING, TypeGuard, Optional, overload, Final
from cpmpy.exceptions import IncompleteFunctionError

if TYPE_CHECKING:
    # only import for type checking
    from cpmpy.expressions.core import ExprLike, BoolExprLike, Expression
    from cpmpy.expressions.variables import NDVarArray

# Exact numpy scalar types that npint2int() turns into native Python ints.
# Membership is tested with `type(el) in NP_TYPES`, i.e. on the exact type.
NP_TYPES: Final = frozenset({
    np.int8, np.int16, np.int32, np.int64,
    np.uint8, np.uint16, np.uint32, np.uint64,
    np.bool_
})

def is_bool(arg):
    """Return True if `arg` is a Boolean constant.

    Accepted types are Python `bool`, numpy `np.bool_` and CPMpy `BoolVal`.
    """
    bool_types = (bool, np.bool_, cp.BoolVal)
    return isinstance(arg, bool_types)
def is_int(arg):
    """Return True if `arg` can be interpreted as an integer.

    Besides `int` and `np.integer`, the Boolean types (`bool`, `np.bool_`,
    `BoolVal`) are accepted as well.
    """
    int_like = (bool, np.bool_, cp.BoolVal, int, np.integer)
    return isinstance(arg, int_like)
def is_num(arg):
    """Return True if `arg` is an integer or a float (numpy variants included).

    The Boolean types (`bool`, `np.bool_`, `BoolVal`) are accepted as well.
    """
    numeric = (bool, np.bool_, cp.BoolVal, int, np.integer, float, np.floating)
    return isinstance(arg, numeric)
def is_false_cst(arg):
    """Return True if `arg` is the constant False.

    The constant may be Python's `False`, numpy's `np.False_`, or a `BoolVal`
    whose value is falsy. Anything else yields False.
    """
    # identity checks first: they need no attribute access
    if (arg is False) or (arg is np.False_):
        return True
    return isinstance(arg, cp.BoolVal) and not arg.value()
def is_true_cst(arg):
    """Return True if `arg` is the constant True.

    The constant may be Python's `True` or numpy's `np.True_`; for a `BoolVal`
    the result of its `.value()` is returned. Anything else yields False.
    """
    # identity checks first: they need no attribute access
    if (arg is True) or (arg is np.True_):
        return True
    return isinstance(arg, cp.BoolVal) and arg.value()
def is_boolexpr(expr):
    """Return whether `expr` is a Boolean expression or a Boolean constant.

    Objects exposing an `is_bool` attribute are asked directly through
    `expr.is_bool()`; everything else is checked with `is_bool(expr)`.
    """
    return expr.is_bool() if hasattr(expr, 'is_bool') else is_bool(expr)
def is_pure_list(arg):
    """Return True if `arg` is a Python list or tuple (subclasses included)."""
    return isinstance(arg, list) or isinstance(arg, tuple)
[docs] def is_any_list(arg) -> TypeGuard[list | tuple | np.ndarray]: """ is it a list or tuple or numpy array? """ return isinstance(arg, (list, tuple, np.ndarray))
def flatlist(args):
    """Recursively flatten `args` into one single flat list."""
    return [*_flatten(args)]


def _flatten(args):
    """Lazily yield the leaves of an irregularly nested iterable.

    Every element that is an `Iterable` is descended into, except `str` and
    `bytes`, which are yielded as they are.

    from: https://stackoverflow.com/questions/2158395/flatten-an-irregular-list-of-lists
    """
    for item in args:
        # strings and bytes are iterable but must be kept whole
        if isinstance(item, (str, bytes)) or not isinstance(item, Iterable):
            yield item
            continue
        yield from _flatten(item)
def all_pairs(args):
    """Return a list with every pairwise combination of the elements in `args`."""
    return [pair for pair in combinations(args, 2)]
def argval(a):
    """Return the value of `a` as a native Python object.

    If `a` has a `value` attribute, `a.value()` is used, otherwise `a` itself.
    When `a.value()` raises `IncompleteFunctionError`, False is returned if
    `a` is a Boolean Expression; in every other case the error is re-raised.
    numpy scalars are converted with `.item()`.

    The attribute is detected with hasattr instead of isinstance to avoid
    a circular dependency.
    """
    if not hasattr(a, "value"):
        val = a
    else:
        try:
            val = a.value()
        except IncompleteFunctionError as e:
            if isinstance(a, cp.expressions.core.Expression) and a.is_bool():
                return False
            raise e

    # ensure it is a Python native value
    return val.item() if isinstance(val, np.generic) else val
def argvals(arr):
    """Apply `argval` to `arr`, recursing into lists, tuples and numpy arrays.

    A list, tuple or numpy array is mapped to a (nested) Python list of
    values; any other argument is passed to `argval` directly.
    """
    if not is_any_list(arr):
        return argval(arr)
    return [argvals(sub) for sub in arr]
def argvals_intexpr(lst: Iterable[int|Expression]) -> Optional[list[int]]:
    """A well-typed helper to get the values of a list of int|Expression.

    Integers are taken as they are; for every other element `.value()` is
    called. Returns None as soon as one of those calls returns None.
    """
    collected: list[int] = []
    for item in lst:
        current = item if isinstance(item, int) else item.value()
        if current is None:
            # an unassigned expression makes the whole result undefined
            return None
        collected.append(current)
    return collected
def eval_comparison(str_op, lhs, rhs):
    """
        Internal function: evaluates the textual `str_op` comparison operator
        lhs <str_op> rhs

        Valid str_op's:
        * '=='
        * '!='
        * '>'
        * '>='
        * '<'
        * '<='

        Especially useful in decomposition and transformation functions that
        already involve a comparison.

        numpy integers and numpy Booleans on either side are converted to a
        Python int before the comparison is applied.

        Raises:
            ValueError: if `str_op` is not one of the operators listed above.
    """
    if isinstance(lhs, (np.integer, np.bool_)):
        lhs = int(lhs)
    if isinstance(rhs, (np.integer, np.bool_)):
        rhs = int(rhs)

    if str_op == '==':
        return lhs == rhs
    if str_op == '!=':
        return lhs != rhs
    if str_op == '>':
        return lhs > rhs
    if str_op == '>=':
        return lhs >= rhs
    if str_op == '<':
        return lhs < rhs
    if str_op == '<=':
        return lhs <= rhs
    # ValueError is still an Exception, so existing `except Exception` handlers keep working;
    # a single formatted message avoids the tuple-style rendering of a two-argument Exception
    raise ValueError(f"Not a known comparison: {str_op!r}")
def get_bounds(expr):
    """ Return the bounds of the expression, as appropriately rounded integers.

    - Expression or NDVarArray: the result of `expr.get_bounds()`
    - list, tuple or numpy array: a pair `(lbs, ubs)` of lists with the bounds
      of every element (computed recursively); `([], [])` when it is empty
    - Boolean constant: `(int(expr), int(expr))`
    - other numeric constant: `(math.floor(expr), math.ceil(expr))`
    """
    if isinstance(expr, (cp.expressions.core.Expression, cp.expressions.variables.NDVarArray)):
        return expr.get_bounds()
    elif is_any_list(expr):
        bounds = [get_bounds(e) for e in expr]
        if not bounds:
            # zip(*[]) yields nothing, so unpacking it into two names would raise
            return [], []
        lbs, ubs = zip(*bounds)
        return list(lbs), list(ubs)
    else:
        assert is_num(expr), f"All Expressions should have a get_bounds function, `{expr}`"
        if is_bool(expr):
            return int(expr), int(expr)
        return math.floor(expr), math.ceil(expr)
def get_bounds_intexpr(lst: Iterable[int|Expression]) -> tuple[list[int], list[int]]:
    """A well-typed helper to get the bounds of a list of int|Expression's.

    An integer is its own lower and upper bound; for every other element the
    pair returned by `.get_bounds()` is used. Returns `(lower_bounds, upper_bounds)`.
    """
    lower: list[int] = []
    upper: list[int] = []
    for item in lst:
        lo, hi = (item, item) if isinstance(item, int) else item.get_bounds()
        lower.append(lo)
        upper.append(hi)
    return lower, upper
# the two overloads below are declarations for typing purposes only
@overload
def implies(expr: NDVarArray, other: BoolExprLike, simplify: bool = False) -> NDVarArray: ...
@overload
def implies(expr: Expression|bool|np.bool_, other: BoolExprLike, simplify: bool = False) -> Expression: ...

def implies(expr: NDVarArray|BoolExprLike, other: BoolExprLike, simplify: bool = False) -> NDVarArray|Expression:
    """Implication constraint: ``expr -> other``.

    Like :func:`~cpmpy.expressions.core.Expression.implies`, but also safe when 'expr' is not an Expression

    Args:
        expr (NDVarArray|BoolExprLike): the left-hand-side of the implication
        other (BoolExprLike): the right-hand-side of the implication
        simplify (bool): if True, simplify by eliminating True/False constants
            (might remove expressions & their variables from user-view);
            only forwarded when `expr` is an Expression or NDVarArray

    Returns:
        Expression: the implication constraint, or a BoolVal when `expr` is a constant

    Raises:
        ValueError: if `expr` is neither an Expression/NDVarArray nor a Boolean constant

    Simplification rules:
        - Expr -> True :: BoolVal(True)  (by expr.implies())
        - Expr -> False :: ~Expr  (by expr.implies())
        - True -> other :: other
        - False -> other :: BoolVal(True)
    """
    with_implies = (cp.expressions.core.Expression, cp.expressions.variables.NDVarArray)
    if isinstance(expr, with_implies):
        # both types implement .implies()
        return expr.implies(other, simplify=simplify)

    if is_true_cst(expr):
        # True -> other :: other (wrapped in a BoolVal when it is not an Expression)
        return other if isinstance(other, cp.expressions.core.Expression) else cp.BoolVal(other)

    if is_false_cst(expr):
        # False -> other :: BoolVal(True)
        return cp.BoolVal(True)

    raise ValueError(f"implies: expr must be an Expression or a boolean, got {type(expr)}")
# Specific stuff for scheduling constraints
def get_nonneg_args(args, condition=None):
    """ Replace arguments with negative lowerbound with their nonnegative counterpart

    Arguments:
        - args: list of expressions
        - condition: list of boolean expressions, indicating whether the argument
          is present or not (e.g., optional tasks); defaults to True for every argument

    Returns:
        A pair `(new_args, cons)`: `new_args` has one entry per argument, either the
        argument itself (lower bound >= 0) or a fresh integer variable with lower
        bound 0; `cons` holds, per replaced argument, the constraint
        `cond -> (arg == new variable)`. Both lists are empty when `args` is empty.
    """
    if condition is None:
        condition = [True] * len(args)
    assert len(args) == len(condition), f"Args and is_present must have the same length but got {len(args)} and {len(condition)}"

    # Compute all bounds before creating any variable. Iterating over the pairs directly
    # (instead of unpacking `zip(*bounds)`) also works when there are no arguments at all.
    bounds = [get_bounds(arg) for arg in args]

    new_args = []
    cons = []
    for (lb, ub), arg, cond in zip(bounds, args, condition):
        if lb >= 0:
            new_args.append(arg)
            continue
        # if ub < 0 the replacement variable is fixed to 0
        iv = cp.intvar(0, max(ub, 0))
        cons.append(implies(cond, arg == iv))  # will always be False if ub < 0
        new_args.append(iv)
    return new_args, cons
# Specific stuff for ShortTabel global (should this be in globalconstraints.py instead?) STAR = "*" # define constant here
def is_star(arg):
    """Check whether `arg` is the star symbol as used in the ShortTable global constraint.

    `arg` must have the same type as the `STAR` constant and compare equal to it.
    """
    if not isinstance(arg, type(STAR)):
        return False
    return arg == STAR
def npint2int(iter: Iterable[ExprLike]) -> tuple[int|Expression, ...]:
    """Convert numpy values in iterable to Python integers, return as tuple.

    Only elements whose exact type is listed in `NP_TYPES` are converted with
    `int()`; all other elements are passed through unchanged.
    """
    converted = [int(el) if type(el) in NP_TYPES else el for el in iter]
    return tuple(converted)  # type: ignore  # it can't see we're removing the np.integers