# Source code for cutcutcodec.core.compilation.sympy_to_torch.preprocess

"""Prepare the work for the Printer, decompose and analyse."""

import itertools
import re
from collections import OrderedDict

import sympy
from sympy.core.basic import Basic
from sympy.core.containers import Tuple
from sympy.core.numbers import Float, Integer
from sympy.core.symbol import Symbol
from sympy.logic.boolalg import BooleanFalse, BooleanTrue
from sympy.simplify.cse_main import cse


def _broadcast(
    symb_expr: list[tuple[Symbol, Basic]], shapes: set[frozenset[Symbol]],
) -> dict[Symbol, frozenset[Symbol]]:
    r"""Find the shape of all the sub vars.

    Complexity o(n).

    Parameters
    ----------
    symb_expr : list[tuple[sympy.core.symbol.Symbol, sympy.core.basic.Basic]]
        The list of symbols and atomic expressions.
    shapes : set[frozenset[sympy.core.symbol.Symbol]]
        The initials shapes. For a more complete description, please refer to
        ``cutcutcodec.core.compilation.sympy_to_torch.preprocess.preprocess``.

    Returns
    -------
    shapes : dict[sympy.core.symbol.Symbol, frozenset[sympy.core.symbol.Symbol]]
        All the shapes, for each intermediate vars, associate the broadcast shape of the tensor.

    Examples
    --------
    >>> from pprint import pprint
    >>> from sympy.abc import c, x
    >>> from sympy import Number, Tuple, sin, symbols
    >>> from cutcutcodec.core.compilation.sympy_to_torch.preprocess import _broadcast
    >>> _, _0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11 = symbols("_ _:12")
    >>> tree = [(_1, c**(-2)), (_2, 1/x), (_0, _1*_2), (_3, Number(0)), (_6, sin(x)),
    ...         (_5, sin(_6)), (_4, _5 + 1), (_7, c), (_8, c), (_9, x), (_10, x),
    ...         (_11, _0), (_, Tuple(_3, _7, _8, _9, _10, _0, _11, _4))]
    >>> shapes = _broadcast(tree, set())
    >>> pprint({v: sorted(shapes[v], key=str) for v in sorted(shapes, key=str)}, sort_dicts=False)
    {_0: [c, x],
     _1: [c],
     _10: [x],
     _11: [c, x],
     _2: [x],
     _3: [],
     _4: [x],
     _5: [x],
     _6: [x],
     _7: [c],
     _8: [c],
     _9: [x],
     c: [c],
     x: [x]}
    >>> shapes = _broadcast(tree, {frozenset({c, x})})
    >>> pprint({v: sorted(shapes[v], key=str) for v in sorted(shapes, key=str)}, sort_dicts=False)
    {_0: [c],
     _1: [c],
     _10: [c],
     _11: [c],
     _2: [c],
     _3: [],
     _4: [c],
     _5: [c],
     _6: [c],
     _7: [c],
     _8: [c],
     _9: [c],
     c: [c],
     x: [c]}
    >>>

    """
    # simplification of given shapes, remove single and merge common
    shapes_simplified = []
    for shape in (set(s) for s in shapes if s):
        merge = False
        for shape_ in shapes_simplified:
            if shape & shape_:
                shape_ |= shape  # has to be set, not frozenset for inplace operation
                merge = True
        if not merge:
            shapes_simplified.append(shape)
    shapes_simplified = [frozenset(s) for s in shapes_simplified]

    # parse shapes into dict
    min_of_set = {s: frozenset((min(s, key=str),)) for s in shapes_simplified}
    all_shapes = {symb: min_of_set[s] for s in shapes_simplified for symb in s}

    # exploration of the tree
    for symb, expr in symb_expr:
        if isinstance(expr, Tuple):
            continue
        if (free_symbols := expr.free_symbols):
            all_shapes[symb] = frozenset.union(
                *(all_shapes.get(s, frozenset((s,))) for s in free_symbols),
            )
            for free_symbol in free_symbols:
                all_shapes[free_symbol] = all_shapes.get(free_symbol, frozenset((free_symbol,)))
        else:  # case expr is numbers
            all_shapes[symb] = frozenset()
    return all_shapes


def _expr_to_atomic(expr: Basic, *, _symbols=None) -> list[tuple[Symbol, Basic]]:
    """Apply ``cse`` and split the sub patterns.

    The expression is flattened into an ordered list of assignments.
    Each non atomic argument of a sub expression is computed by a previous
    assignment and replaced by the corresponding intermediate symbol.
    Sum and product expressions can contain more than 2 terms.

    Parameters
    ----------
    expr : sympy.core.basic.Basic
        The sympy expression to split.
    _symbols : iterator, optional
        Internal argument of the recursive calls, it should be left to ``None``.
        It is the shared iterator that yields the new intermediate symbols.
        When it is provided, ``cse`` is not called again.

    Returns
    -------
    replacements : list of (Symbol, expression) pairs
        All of the common subexpressions that were replaced.
        All subexpressions are atomic.

    Examples
    --------
    >>> from pprint import pprint
    >>> from sympy.abc import c, x
    >>> from sympy import Tuple, sin
    >>> from cutcutcodec.core.compilation.sympy_to_torch.preprocess import _expr_to_atomic
    >>> exp = Tuple(0, c, c, x, x, c**-2/x, c**-2/x, sin(sin(x)) + 1)
    >>> pprint(_expr_to_atomic(exp))
    [(_1, c**(-2)),
     (_2, 1/x),
     (_0, _1*_2),
     (_5, sin(x)),
     (_4, sin(_5)),
     (_3, _4 + 1),
     (_6, 0),
     (_7, c),
     (_8, c),
     (_9, x),
     (_10, x),
     (_11, _0),
     (_, (_6, _7, _8, _9, _10, _0, _11, _3))]
    >>>

    """
    # initialisation
    if _symbols is None:
        # top level call: one endless generator of names _0, _1, ... is shared
        # between ``cse`` and all the recursive calls
        _symbols = iter(Symbol(f"_{i}") for i in itertools.count())
        rep, last = cse(expr, symbols=_symbols, order="none", list=False)  # fastest as possible
        # the reduced expression is stored last, a tuple result takes the reserved name "_"
        rep.append(((Symbol("_") if isinstance(expr, Tuple) else next(_symbols)), last))
    else:  # if cse is already called
        rep = [(next(_symbols), expr)]

    # main
    atom_rep = []
    for var, sub_expr in rep:
        if sub_expr.is_Atom:  # nothing to split
            atom_rep.append((var, sub_expr))
            continue
        subs = {}  # to each non atomic argument, associate its replacement
        for arg in sub_expr.args:
            if (
                arg in subs  # the same argument is not computed several times
                or arg.is_Atom  # only the non atomic arguments are replaced
            ):
                continue
            # the steps that compute ``arg`` are appended, the last one holds the result
            atom_rep += _expr_to_atomic(arg, _symbols=_symbols)
            # a nested tuple is put back inline (its own step is removed),
            # otherwise the argument is replaced by the symbol of the last step
            subs[arg] = (
                atom_rep.pop(-1)[1] if isinstance(atom_rep[-1][1], Tuple) else atom_rep[-1][0]
            )
        if subs:
            sub_expr = sub_expr.xreplace(subs)
        # make sure no duplicate and no intermediate variables
        if isinstance(sub_expr, Tuple) and str(var) == "_":
            args = []  # this ensures the independence of the output variables
            for arg in sub_expr.args:
                # an output that is already listed, or that is not a symbol
                # of the form "_<number>", gets its own new symbol
                if arg in args or not re.fullmatch(r"_\d+", str(arg)):
                    atom_rep.append((next(_symbols), arg))
                    args.append(atom_rep[-1][0])
                else:
                    args.append(arg)
            sub_expr = Tuple(*args)
        atom_rep.append((var, sub_expr))

    return atom_rep


def _get_args(symb_expr: list[tuple[Symbol, Basic]]) -> tuple[set[Symbol], set[Symbol]]:
    """Search the parameters and islotate wich one are changing inplace.

    Complexity o(n).

    Parameters
    ----------
    symb_expr : list[tuple[sympy.core.symbol.Symbol, sympy.core.basic.Basic]]
        The list of symbols and atomic expressions.

    Returns
    -------
    all_args : set[sympy.core.symbol.Symbol]
        All the input arguments
    args_no_safe : set[sympy.core.symbol.Symbol]
        The subset of arguments that is not read-only.
        These arguments are modified inplace in the function
        If the value of these arguments has to be concerved,
        then a copy of these arguments should be passed to the function.
    alloc : set[sympy.core.symbol.Symbol]
        All the internal sub vars

    Examples
    --------
    >>> from pprint import pprint
    >>> from sympy.abc import c, x
    >>> from sympy import Number, Tuple, sin, symbols
    >>> from cutcutcodec.core.compilation.sympy_to_torch.preprocess import _get_args
    >>> _, _0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11 = symbols("_ _:12")
    >>> tree = [(_1, c**(-2)), (_2, 1/x), (_0, _1*_2), (_3, Number(0)), (_6, sin(x)),
    ...         (_5, sin(_6)), (_4, _5 + 1), (_7, c), (_8, c), (_9, x), (_10, x),
    ...         (_11, _0), (_, Tuple(_3, _7, _8, _9, _10, _0, _11, _4))]
    >>> args, no_safe, alloc = _get_args(tree)
    >>> sorted(args, key=str)
    [c, x]
    >>> sorted(no_safe, key=str)
    []
    >>> sorted(alloc, key=str)
    [_, _0, _1, _10, _11, _2, _3, _4, _5, _6, _7, _8, _9]
    >>>

    """
    all_args, no_safe, alloc = set(), set(), set()
    for symb, expr in symb_expr:
        symbs = expr.free_symbols
        all_args |= symbs - alloc
        if symb in all_args and expr != symb:
            no_safe.add(symb)
        else:
            alloc.add(symb)
    return all_args, no_safe, alloc


def _isolate_cst_dyn(
    symb_expr: list[tuple[Symbol, Basic]], cst_args: set[Symbol],
) -> tuple[list[tuple[Symbol, Basic]], list[tuple[Symbol, Basic]]]:
    """Split the tree into a cacheable part and a dynamic part.

    Complexity o(n).

    Parameters
    ----------
    symb_expr : list[tuple[sympy.core.symbol.Symbol, sympy.core.basic.Basic]]
        Returned value of ``_expr_to_atomic``.
    cst_args : set[sympy.core.symbol.Symbol]
        The constants input parameters.
        The subexpressions that only depend on these parameters will be cached.

    Returns
    -------
    cst_tree : list[tuple[sympy.core.symbol.Symbol, sympy.core.basic.Basic]]
        The graph to compute the constant sub expressions.
        The last value is a ``sympy.core.containers.Tuple``.
    dyn_tree : list[tuple[sympy.core.symbol.Symbol, sympy.core.basic.Basic]]
        The main tree containing only dynamic expressions.

    Examples
    --------
    >>> from pprint import pprint
    >>> from sympy.abc import c, x
    >>> from sympy import Number, Tuple, sin, symbols
    >>> from cutcutcodec.core.compilation.sympy_to_torch.preprocess import _isolate_cst_dyn
    >>> _, _0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11 = symbols("_ _:12")
    >>> tree = [(_1, c**(-2)), (_2, 1/x), (_0, _1*_2), (_3, Number(0)), (_6, sin(x)),
    ...         (_5, sin(_6)), (_4, _5 + 1), (_7, c), (_8, c), (_9, x), (_10, x),
    ...         (_11, _0), (_, Tuple(_3, _7, _8, _9, _10, _0, _11, _4))]
    >>> cst, dyn = _isolate_cst_dyn(tree, {c})
    >>> pprint(cst)
    [(_1, c**(-2)), (_7, c), (_8, c), (_, (_1, _7, _8))]
    >>> pprint(dyn)
    [(_2, 1/x),
     (_0, _1*_2),
     (_3, 0),
     (_6, sin(x)),
     (_5, sin(_6)),
     (_4, _5 + 1),
     (_9, x),
     (_10, x),
     (_11, _0),
     (_, (_3, _7, _8, _9, _10, _0, _11, _4))]
    >>>

    """
    # first pass: collect the symbols whose value only depends on constant inputs
    cacheable = set()
    for name, value in symb_expr:
        if value.is_number or isinstance(value, (BooleanTrue, BooleanFalse)):
            continue  # the literals stay in the dynamic tree
        if all(s in cst_args or s in cacheable for s in value.free_symbols):
            cacheable.add(name)

    # second pass: dispatch each step, the relative order is preserved in both trees
    cst_tree = [(name, value) for name, value in symb_expr if name in cacheable]
    dyn_tree = [(name, value) for name, value in symb_expr if name not in cacheable]

    # if everything is constant, the dynamic tree only forwards the last symbol
    if not dyn_tree:
        dyn_tree.append((Symbol("_"), symb_expr[-1][0]))

    # only the cached symbols that are read by the dynamic tree are exported
    exported = {
        s for _, value in dyn_tree for s in value.free_symbols if s in cacheable
    }
    cst_tree.append((Symbol("_"), Tuple(*sorted(exported, key=str))))

    return cst_tree, dyn_tree


def _limit_realoc(
    symb_expr: list[tuple[Symbol, Basic]],
    broadcasted_shapes: dict[Symbol, frozenset[Symbol]],
    safe: set[Symbol],
) -> tuple[dict[Symbol, frozenset[Symbol]], list[tuple[Symbol, Basic]]]:
    """Optimise memory by reusing as many old variables as possible.

    Complexity o(n**2).
    The ``sympy.core.containers.Tuple`` expressions are not considered.

    A variable becomes reusable once no later step reads it.
    The result of a step is then stored in a reusable variable of the same
    broadcasted shape, if any, instead of a new one.

    Parameters
    ----------
    symb_expr : list[tuple[sympy.core.symbol.Symbol, sympy.core.basic.Basic]]
        The list of symbols and atomic expressions.
    broadcasted_shapes : dict[sympy.core.symbol.Symbol, frozenset[sympy.core.symbol.Symbol]]
        For each var, associate the broadcasted shape,
        Output of ``cutcutcodec.core.compilation.sympy_to_torch.preprocess._broadcast``.
    safe : set[sympy.core.symbol.Symbol]
        The variables to keep safe. For a more complete description, please refer to
        ``cutcutcodec.core.compilation.sympy_to_torch.preprocess.preprocess``.

    Returns
    -------
    alloc : dict[sympy.core.symbol.Symbol, frozenset[sympy.core.symbol.Symbol]]
        The intermediate variables to be declared and their respective dimensions.
    symb_expr : list[tuple[sympy.core.symbol.Symbol, sympy.core.basic.Basic]]
        The new equivalent tree that minimizes the reallocation and takes care of the shapes.

    Examples
    --------
    >>> from pprint import pprint
    >>> from sympy.abc import c, x
    >>> from sympy import Number, Tuple, sin, symbols
    >>> from cutcutcodec.core.compilation.sympy_to_torch.preprocess import _broadcast, _limit_realoc
    >>> _, _0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11 = symbols("_ _:12")
    >>> tree = [(_1, c**(-2)), (_2, 1/x), (_0, _1*_2), (_3, Number(0)), (_6, sin(x)),
    ...         (_5, sin(_6)), (_4, _5 + 1), (_7, c), (_8, c), (_9, x), (_10, x),
    ...         (_11, _0), (_, Tuple(_3, _7, _8, _9, _10, _0, _11, _4))]
    >>> alloc, tree = _limit_realoc(tree, _broadcast(tree, set()), {c, x})
    >>> pprint({v: sorted(alloc[v], key=str) for v in sorted(alloc, key=str)}, sort_dicts=False)
    {_0: [c, x], _1: [c], _10: [x], _11: [c, x], _2: [x], _3: [], _8: [c], _9: [x]}
    >>> pprint(tree)
    [(_1, c**(-2)),
     (_2, 1/x),
     (_0, _1*_2),
     (_3, 0),
     (_2, sin(x)),
     (_2, sin(_2)),
     (_2, _2 + 1),
     (_1, c),
     (_8, c),
     (_9, x),
     (_10, x),
     (_11, _0),
     (_, (_3, _1, _8, _9, _10, _0, _11, _2))]
    >>> alloc, tree = _limit_realoc(tree, _broadcast(tree, {frozenset({c, x})}), set())
    >>> pprint({v: sorted(alloc[v], key=str) for v in sorted(alloc, key=str)}, sort_dicts=False)
    {_0: [c], _1: [c], _11: [c], _2: [c], _3: [], _9: [c]}
    >>> pprint(tree)
    [(_1, c**(-2)),
     (_2, 1/x),
     (_0, _1*_2),
     (_3, 0),
     (_2, sin(x)),
     (_2, sin(_2)),
     (_2, _2 + 1),
     (_1, c),
     (_9, x),
     (_11, _0),
     (_, (_3, _1, c, _9, x, _0, _11, _2))]
    >>>

    """
    args, _, _ = _get_args(symb_expr)

    # at each step, find the new free sub symbols
    # used[i] is the set of the symbols read by the step i or by a later step,
    # the list is built backward and ends with the empty set
    used = [set()]
    for _, expr in reversed(symb_expr):
        used.insert(0, used[0] | expr.free_symbols)
    # a symbol read at step i but by no later step is released by the step i,
    # the safe symbols are never released
    all_new_free: list[set[Symbol]] = (  # each step, the new free vars o(n**2)
        [(u1-u2)-safe for u1, u2 in zip(used[:-1], used[1:])]
    )

    # replacement line by line
    new_tree: list[tuple[Symbol, Basic]] = []  # the new tree with substitutions
    free: set[Symbol] = set()  # the free symbols at the current step i
    subs: dict[Symbol, Symbol] = {}  # each old name, associate the new one
    for new_free, (old_symb, old_expr) in zip(all_new_free, symb_expr):
        # replace old vars by new
        symb = subs.get(old_symb, old_symb)
        expr = old_expr.xreplace(subs)
        free |= {subs.get(s, s) for s in new_free}
        # particular case
        if isinstance(old_expr, Tuple):  # particular case of tuple, end of tree
            new_tree.append((symb, expr))
            break
        # selection of the new substitution variable
        # The candidates are the free variables that have the same shape as the result.
        # In order of priority, the selected one is: the variable equal to the expression
        # itself (the step becomes useless), a variable whose name starts with "_",
        # the first name in alphabetic order. Without candidate, the current name is kept.
        # The lambda reads the loop variable ``expr`` (pylint W0640, ruff B023),
        # it is safe here because it is called immediately.
        # The following equivalent code avoids the warning:
        # symbs = {f for f in free if broadcasted_shapes[f] == broadcasted_shapes[symb]}
        # criteria = {s: ((s != expr), (not str(s).startswith("_")), str(s)) for s in symbs}
        # symb = min(symbs, key=criteria.get, default=symb)
        symb = min(
            {f for f in free if broadcasted_shapes[f] == broadcasted_shapes[symb]},
            key=lambda s: (
                (s != expr), (not str(s).startswith("_")), str(s),  # noqa: B023
            ),
            default=symb,
        )
        if symb != expr:  # the assignment of a variable to itself is dropped
            new_tree.append((symb, expr))
        # updates the context
        free -= {symb}  # the selected variable is in use again
        subs[old_symb] = symb
        subs = {o: subs.get(n, n) for o, n in subs.items()}  # resolve the chained renamings
        # the shapes are indexed by the new names, the replaced names disappear
        broadcasted_shapes = {subs.get(s, s): v for s, v in broadcasted_shapes.items()}

    # search for allocated variables and their size
    alloc = {a: s for a, s in broadcasted_shapes.items() if a not in args}
    return alloc, new_tree


def _rename(
    symb_expr: list[tuple[Symbol, Basic]], subs: dict[Symbol, Symbol], *, return_subs=False,
) -> list[tuple[Symbol, Basic]]:
    """Replace and rename the symbols in canonical order.

    Complexity o(n).

    Parameters
    ----------
    symb_expr : list[tuple[sympy.core.symbol.Symbol, sympy.core.basic.Basic]]
        The list of symbols and atomic expressions.
    subs : dict[sympy.core.symbol.Symbol, sympy.core.symbol.Symbol]
        The replacement name of some symbols.
    return_subs : boolean, defaul=False
        If set to True, return the dictionary of the substitutions.

    Returns
    -------
    new_tree : list[tuple[sympy.core.symbol.Symbol, sympy.core.basic.Basic]]
        The renamed elements.

    Examples
    --------
    >>> from pprint import pprint
    >>> from sympy.abc import c, x
    >>> from sympy import Number, Tuple, sin, symbols
    >>> from cutcutcodec.core.compilation.sympy_to_torch.preprocess import _rename
    >>> _, _0, _1, _2, _3, _4, _5, _6, _7, _8, _9, _10, _11 = symbols("_ _:12")
    >>> tree = [(_1, c**(-2)), (_2, 1/x), (_0, _1*_2), (_3, Number(0)), (_6, sin(x)),
    ...         (_5, sin(_6)), (_4, _5 + 1), (_7, c), (_8, c), (_9, x), (_10, x),
    ...         (_11, _0), (_, Tuple(_3, _7, _8, _9, _10, _0, _11, _4))]
    >>> pprint(_rename(tree, {}))
    [(_0, c**(-2)),
     (_1, 1/x),
     (_2, _0*_1),
     (_3, 0),
     (_4, sin(x)),
     (_5, sin(_4)),
     (_6, _5 + 1),
     (_7, c),
     (_8, c),
     (_9, x),
     (_10, x),
     (_11, _2),
     (_, (_3, _7, _8, _9, _10, _2, _11, _6))]
    >>> tree = [(_1, c**(-2)), (_2, 1/x), (_0, _1*_2), (_3, Number(0)), (_2, sin(x)),
    ...         (_2, sin(_2)), (_2, _2 + 1), (_1, c), (_9, x), (_11, _0),
    ...         (_, Tuple(_3, _1, c, _9, x, _0, _11, _2))]
    >>> pprint(_rename(tree, {}))
    [(_0, c**(-2)),
     (_1, 1/x),
     (_2, _0*_1),
     (_3, 0),
     (_1, sin(x)),
     (_1, sin(_1)),
     (_1, _1 + 1),
     (_0, c),
     (_4, x),
     (_5, _2),
     (_, (_3, _0, c, _4, x, _2, _5, _1))]
    >>>

    """
    subs_local = subs.copy()
    renamed_tree = []
    symbols = iter(Symbol(f"_{i}") for i in itertools.count())

    for symb, expr in symb_expr:
        if symb not in subs_local and re.fullmatch(r"_\d+", str(symb)):
            subs_local[symb] = next(symbols)
        renamed_tree.append((subs_local.get(symb, symb), expr.xreplace(subs_local)))

    if return_subs:
        return subs_local, renamed_tree
    return renamed_tree


def evalf(expr: Basic, prec: int = 37, simplify: bool = False) -> Basic:
    """Numerical eval and simplification of the expression.

    Parameters
    ----------
    expr : sympy.Expr
        The sympy expression to simplify as numerically evaluable.
        If it is a ``sympy.core.containers.Tuple``, each item is processed independently
        with the same ``prec`` and ``simplify`` values.
    prec : int, default=37
        The number of decimals, to comply with the standards, you can use the following values:

            * float128 -> 37
            * float64 -> 18
            * float32 -> 10

    simplify : boolean, default=False
        If set to True, it tries to simplify the expression
        in order to improve the numerical evaluation.

    Returns
    -------
    sympy.Expr
        The quite equivalent expression with floats.

    Examples
    --------
    >>> import sympy
    >>> from cutcutcodec.core.compilation.sympy_to_torch.preprocess import evalf
    >>> evalf(sympy.pi)
    3.141592653589793238462643383279502884
    >>> evalf(sympy.sin(sympy.sin(1)))
    0.7456241416655578888931510704303837921
    >>> evalf(sympy.sqrt(2))
    1.414213562373095048801688724209698079
    >>> evalf(sympy.sympify("-2.0*x"))
    -2.0*x
    >>> evalf(sympy.sympify("(x/(2.0*x+2.0))**100.0"))
    (x/(2.0*x + 2.0))**100.0
    >>> evalf(sympy.sympify("sqrt(x)"))
    x**0.5
    >>>

    """
    assert isinstance(expr, Basic), expr.__class__.__name__
    assert isinstance(prec, int), prec.__class__.__name__
    assert prec >= 1, prec
    assert isinstance(simplify, bool), simplify.__class__.__name__

    # to numerical
    if isinstance(expr, Tuple):
        # the options are forwarded, otherwise the items would be evaluated with the defaults
        return Tuple(*(evalf(item, prec=prec, simplify=simplify) for item in expr))
    # the non integer numbers are evaluated first, the integers are kept exact
    sub = expr.atoms(sympy.Float, sympy.NumberSymbol, sympy.Rational) - expr.atoms(Integer)
    expr = expr.xreplace({s: s.evalf(n=prec) for s in sub})
    expr = expr.evalf(n=prec)
    # float to int, only for the values -1, 0 and 1
    sub = {s: round(s) for s in expr.atoms(Float) if float(s) in {-1.0, 0.0, 1.0}}
    expr = expr.xreplace(sub)

    # simplification
    if not simplify:
        return expr
    expr = sympy.rcollect(expr, *sorted(expr.atoms(sympy.Float, sympy.Symbol), key=str))
    expr = sympy.trigsimp(expr)
    expr = sympy.logcombine(expr, force=True)
    expr = sympy.powsimp(expr, force=True, deep=True)
    return expr


def preprocess(
    expr: Basic, cst_args: set[Symbol], shapes: set[frozenset[Symbol]], safe: set[Symbol],
) -> dict[str, object]:
    """Decompose and analyse the expression for the printer.

    Parameters
    ----------
    expr : sympy.core.basic.Basic
        The complete sympy expression to compile.
    cst_args : set[sympy.core.symbol.Symbol]
        Arguments that change infrequently enough to be cached.
        They have to be free symbols of ``expr``.
    shapes : set[frozenset[sympy.core.symbol.Symbol]]
        If some parameters have the same shape, it is possible to give this information
        in order to find a more optimal solution to limit the allocations.
        This variable represents the set of all tensor subsets with the same shapes.
        For example, {frozenset({a, b, c}), frozenset({x, y})} means that a, b, and c
        are the same shape, and x and y as well.
    safe : set[sympy.core.symbol.Symbol]
        A subset of arguments that should definitely not be modified in place
        or returned without a copy. The variables provided to this set are safe.
        Arguments not present in this set can be reused internally in order to optimize memory.

    Returns
    -------
    tree : dict
        cst_args : set[sympy.core.symbol.Symbol]
            All the inputs arguments required for the cst tree input.
        cst_alloc : dict[sympy.core.symbol.Symbol, set[sympy.core.symbol.Symbol]]
            The intermediate variables to be declared
            and their respective dimensions for cst tree.
        cst_tree : list[tuple[sympy.core.symbol.Symbol, None | sympy.core.basic.Basic]]
            Each steps of the cached function.
        dyn_args : set[sympy.core.symbol.Symbol]
            All the inputs arguments required for the dyn tree input.
        dyn_alloc : dict[sympy.core.symbol.Symbol, set[sympy.core.symbol.Symbol]]
            The intermediate variables to be declared
            and their respective dimensions for dyn tree.
        dyn_tree : list[tuple[sympy.core.symbol.Symbol, None | sympy.core.basic.Basic]]
            Each steps of the main function.

    Examples
    --------
    >>> from pprint import pprint
    >>> from sympy.abc import c, x
    >>> from sympy import Tuple, sin
    >>> from cutcutcodec.core.compilation.sympy_to_torch.preprocess import preprocess
    >>> exp = Tuple(0, c, c, x, x, c**-2/x, c**-2/x, sin(sin(x)) + 1)
    >>> def print_tree(tree):
    ...     print(
    ...         "cst_alloc:",
    ...         {
    ...             s: sorted(tree["cst_alloc"][s], key=str)
    ...             for s in sorted(tree["cst_alloc"], key=str)
    ...         }
    ...     )
    ...     print("cst_tree:")
    ...     pprint(tree["cst_tree"])
    ...     print(
    ...         "dyn_alloc:",
    ...         {
    ...             s: sorted(tree["dyn_alloc"][s], key=str)
    ...             for s in sorted(tree["dyn_alloc"], key=str)
    ...         }
    ...     )
    ...     print("dyn_tree:")
    ...     pprint(tree["dyn_tree"])
    ...
    >>> tree = preprocess(exp, {c}, set(), {c, x})
    >>> print_tree(tree)
    cst_alloc: {_cst_0: [c], _cst_1: [c], _cst_2: [c]}
    cst_tree:
    [(_cst_0, c**(-2)), (_cst_1, c), (_cst_2, c), (_, (_cst_0, _cst_1, _cst_2))]
    dyn_alloc: {_0: [x], _1: [c, x], _2: [], _3: [x], _4: [x], _5: [c, x]}
    dyn_tree:
    [(_0, 1/x),
     (_1, _0*_cst_0),
     (_0, sin(x)),
     (_0, sin(_0)),
     (_0, _0 + 1),
     (_2, 0),
     (_3, x),
     (_4, x),
     (_5, _1),
     (_, (_2, _cst_1, _cst_2, _3, _4, _1, _5, _0))]
    >>> tree = preprocess(exp, set(), {frozenset({c, x})}, set())
    >>> print_tree(tree)
    cst_alloc: {c: [c], x: [c]}
    cst_tree:
    [(_, ())]
    dyn_alloc: {_0: [c], _1: [c], _2: [], _3: [c], _4: [c], _5: [c]}
    dyn_tree:
    [(_0, c**(-2)),
     (_1, 1/x),
     (_0, _0*_1),
     (_1, sin(x)),
     (_1, sin(_1)),
     (_1, _1 + 1),
     (_2, 0),
     (_3, c),
     (_4, x),
     (_5, _0),
     (_, (_2, _3, c, _4, x, _0, _5, _1))]
    >>>

    """
    assert isinstance(expr, Basic), expr.__class__.__name__
    assert isinstance(cst_args, set), cst_args.__class__.__name__
    assert all(isinstance(s, Symbol) for s in cst_args), cst_args
    assert cst_args.issubset(expr.free_symbols), f"{cst_args} not in {expr}"
    assert isinstance(shapes, set), shapes.__class__.__name__
    assert all(isinstance(g, frozenset) for g in shapes), shapes
    assert all(isinstance(v, Symbol) for g in shapes for v in g), shapes
    assert isinstance(safe, set), safe
    assert all(isinstance(s, Symbol) for s in safe), safe

    # decompose and split
    atomic_tree = _expr_to_atomic(evalf(expr))  # decompose to atomic steps
    # isolate the cachable operations
    cst_tree, dyn_tree = _isolate_cst_dyn(atomic_tree, cst_args)

    # optimise cst tree
    names = cst_tree[-1][1]  # the symbols exported by the cst tree, before any renaming
    cst_args, _, _ = _get_args(cst_tree[:-1])
    cst_alloc, cst_tree = _limit_realoc(cst_tree, _broadcast(cst_tree, shapes), safe=cst_args)
    names = OrderedDict(zip(names, cst_tree[-1][1]))
    subs, cst_tree = _rename(
        cst_tree,
        {
            s: Symbol(f"_cst_{i}") for i, s in enumerate(cst_tree[-1][1])
            if re.fullmatch(r"_\d+", str(s))
        },
        return_subs=True,
    )
    # to each original exported symbol, associate its final name in the cst tree
    names = dict(zip(names, cst_tree[-1][1]))
    cst_alloc = {subs.get(symb, symb): shape for symb, shape in cst_alloc.items()}

    # optimize dyn tree
    dyn_tree = _rename(dyn_tree, names, return_subs=False)
    dyn_alloc, dyn_tree = _limit_realoc(
        dyn_tree,
        _broadcast(cst_tree[:-1]+dyn_tree, shapes),
        safe=(safe | set(names.values())),  # the cached values must not be overwritten
    )
    subs, dyn_tree = _rename(dyn_tree, {n: n for n in names.values()}, return_subs=True)
    dyn_alloc = {subs.get(symb, symb): shape for symb, shape in dyn_alloc.items()}

    # analyse parameters and safety
    dyn_args, _, dyn_alloc_symbs = _get_args(dyn_tree)
    dyn_alloc = {symb: shape for symb, shape in dyn_alloc.items() if symb in dyn_alloc_symbs}

    # combine all the informations
    return {
        "cst_args": cst_args,
        "cst_alloc": cst_alloc,
        "cst_tree": cst_tree,
        "dyn_args": dyn_args,
        "dyn_alloc": dyn_alloc,
        "dyn_tree": dyn_tree,
    }