#!/usr/bin/env python
#-*- coding:utf-8 -*-
##
## python_builtins.py
##
"""
Overrides a number of Python built-ins so that they also work on CPMpy decision variables and expressions.
=================
List of functions
=================
.. autosummary::
    :nosignatures:

    all
    any
    max
    min
    sum
    abs
"""
import builtins # to use the original Python-builtins
import numpy as np
from .utils import is_false_cst, is_true_cst, is_any_list
from .variables import NDVarArray, cpm_array
from .core import Expression, Operator, BoolVal
from .globalfunctions import Minimum, Maximum, Abs
# Overwriting all/any python built-ins
# all: listwise 'and'
def all(iterable):
    """
    all() overwrites the python built-in.

    If iterable contains any `Expression`, the non-constant Boolean
    expressions are combined into an Operator("and", ...); otherwise it
    returns whether all of the arguments are true.

    Any numpy array (NDVarArray is a numpy ndarray subclass) is flattened
    first, so multi-dimensional arrays are treated as one sequence of
    elements, like `max`/`min`/`sum` in this module do.

    Returns:
        - `False` (or `BoolVal(False)` if some element was an `Expression`)
          when some element is a False constant;
        - `True` (or `BoolVal(True)` if some element was an `Expression`)
          when no non-constant Boolean expression remains;
        - the single remaining Boolean expression, if there is exactly one;
        - `Operator("and", [...])` over the remaining expressions otherwise.

    Raises:
        Exception: on a non-Boolean expression, a nested list, or any other
            unexpected argument. Every element is checked, even after a
            False constant has been seen.
    """
    if isinstance(iterable, np.ndarray):
        iterable = iterable.flat  # 1D iterator over all elements, any dimension
    collect = []  # non-constant Boolean expressions
    is_expr, return_false = False, False
    for elem in iterable:
        if isinstance(elem, Expression):  # probably most likely case
            is_expr = True
            if isinstance(elem, BoolVal):
                if not elem.args[0]:  # False constant
                    return_false = True
            elif elem.is_bool():
                collect.append(elem)
            else:
                raise Exception("Non-Boolean argument '{}' to 'all'".format(elem))
        elif is_true_cst(elem):
            pass  # True is neutral for 'and'
        elif is_false_cst(elem):
            return_false = True  # keep scanning so invalid arguments still raise
        elif isinstance(elem, list):
            raise Exception("Encountered list in 'all', only accept non-nested lists")
        else:
            raise Exception("Unexpected argument '{}' to 'all'".format(elem))

    if return_false:
        return BoolVal(False) if is_expr else False
    if not collect:
        return BoolVal(True) if is_expr else True
    if len(collect) == 1:
        return collect[0]
    return Operator("and", collect)
# any: listwise 'or'
def any(iterable):
    """
    any() overwrites the python built-in.

    If iterable contains any `Expression`, the non-constant Boolean
    expressions are combined into an Operator("or", ...); otherwise it
    returns whether any of the arguments is true.

    Any numpy array (NDVarArray is a numpy ndarray subclass) is flattened
    first, so multi-dimensional arrays are treated as one sequence of
    elements, like `max`/`min`/`sum` in this module do.

    Returns:
        - `True` (or `BoolVal(True)` if some element was an `Expression`)
          when some element is a True constant;
        - `False` (or `BoolVal(False)` if some element was an `Expression`)
          when no non-constant Boolean expression remains;
        - the single remaining Boolean expression, if there is exactly one;
        - `Operator("or", [...])` over the remaining expressions otherwise.

    Raises:
        Exception: on a non-Boolean expression, a nested list, or any other
            unexpected argument. Every element is checked, even after a
            True constant has been seen.
    """
    if isinstance(iterable, np.ndarray):
        iterable = iterable.flat  # 1D iterator over all elements, any dimension
    collect = []  # non-constant Boolean expressions
    is_expr, return_true = False, False
    for elem in iterable:
        if isinstance(elem, Expression):  # probably most likely case
            is_expr = True
            if isinstance(elem, BoolVal):
                if elem.args[0]:  # True constant
                    return_true = True
            elif elem.is_bool():
                collect.append(elem)
            else:
                raise Exception("Non-Boolean argument '{}' to 'any'".format(elem))
        elif is_true_cst(elem):
            return_true = True  # keep scanning so invalid arguments still raise
        elif is_false_cst(elem):
            pass  # False is neutral for 'or'
        elif isinstance(elem, list):
            raise Exception("Encountered list in 'any', only accept non-nested lists")
        else:
            raise Exception("Unexpected argument '{}' to 'any'".format(elem))

    if return_true:
        return BoolVal(True) if is_expr else True
    if not collect:
        return BoolVal(False) if is_expr else False
    if len(collect) == 1:
        return collect[0]
    return Operator("or", collect)
def max(*iterable, **kwargs):
    """
    max() overwrites the python built-in to support decision variables.

    Supports the built-in's two call forms: a single iterable argument, or
    several positional arguments. Numpy arrays are reduced over all of their
    elements (via `.flat`).

    If no element is a CPMpy `Expression` (or `NDVarArray`), the built-in
    `max` is called with the given keyword arguments (e.g. `key=`,
    `default=`). Otherwise a `Maximum` functional global constraint is
    constructed; no keyword arguments are supported in that case.
    """
    # With a single positional argument, that argument is the collection.
    args = iterable[0] if len(iterable) == 1 else iterable

    if isinstance(args, np.ndarray):
        # a non-object dtype cannot hold expressions, so the scan is skipped
        if args.dtype != object or not builtins.any(
                isinstance(e, (Expression, NDVarArray)) for e in args.flat):
            return builtins.max(args.flat, **kwargs)
    else:
        args = tuple(args)  # materialise once: a generator can only be scanned once
        if not builtins.any(isinstance(e, (Expression, NDVarArray)) for e in args):
            return builtins.max(args, **kwargs)

    assert len(kwargs)==0, "max over expressions does not support keyword arguments"
    return Maximum(args)
def min(*iterable, **kwargs):
    """
    min() overwrites the python built-in to support decision variables.

    Supports the built-in's two call forms: a single iterable argument, or
    several positional arguments. Numpy arrays are reduced over all of their
    elements (via `.flat`).

    If no element is a CPMpy `Expression` (or `NDVarArray`), the built-in
    `min` is called with the given keyword arguments (e.g. `key=`,
    `default=`). Otherwise a `Minimum` functional global constraint is
    constructed; no keyword arguments are supported in that case.
    """
    # With a single positional argument, that argument is the collection.
    args = iterable[0] if len(iterable) == 1 else iterable

    if isinstance(args, np.ndarray):
        # a non-object dtype cannot hold expressions, so the scan is skipped
        if args.dtype != object or not builtins.any(
                isinstance(e, (Expression, NDVarArray)) for e in args.flat):
            return builtins.min(args.flat, **kwargs)
    else:
        args = tuple(args)  # materialise once: a generator can only be scanned once
        if not builtins.any(isinstance(e, (Expression, NDVarArray)) for e in args):
            return builtins.min(args, **kwargs)

    assert len(kwargs)==0, "min over expressions does not support keyword arguments"
    return Minimum(args)
def sum(iterable, **kwargs):
    """
    sum() overwrites the python built-in to support decision variables.

    If no element is a CPMpy `Expression` (or `NDVarArray`), the built-in
    `sum` is called with the given keyword arguments (e.g. `start=`); numpy
    arrays are summed over all of their elements (via `.flat`).
    Otherwise an Operator("sum", ...) expression is built; no keyword
    arguments are supported in that case.
    """
    if not isinstance(iterable, np.ndarray):
        iterable = tuple(iterable)  # materialise once: a generator can only be scanned once
        has_expr = builtins.any(isinstance(e, (Expression, NDVarArray)) for e in iterable)
        if not has_expr:
            return builtins.sum(iterable, **kwargs)
    else:
        # a non-object dtype cannot hold expressions, so the scan is skipped;
        # `.flat` is read anew each time because a flatiter is consumed by iteration
        only_constants = iterable.dtype != object or not builtins.any(
            isinstance(e, (Expression, NDVarArray)) for e in iterable.flat)
        if only_constants:
            return builtins.sum(iterable.flat, **kwargs)

    assert len(kwargs)==0, "sum over expressions does not support keyword arguments"
    return Operator("sum", iterable)
def abs(element):
    """
    abs() overwrites the python built-in to support decision variables.

    - list-like input (per `is_any_list`): `abs` is applied element-wise and
      the results are wrapped with `cpm_array`
    - a CPMpy `Expression`: an `Abs` functional global constraint is built
    - anything else: delegated to `builtins.abs`
    """
    if is_any_list(element):
        # compat: numpy.abs() accepts sequences, builtins.abs() does not
        return cpm_array(list(map(abs, element)))
    return Abs(element) if isinstance(element, Expression) else builtins.abs(element)