cutcutcodec.core.compilation.sympy_to_torch.preprocess.preprocess

cutcutcodec.core.compilation.sympy_to_torch.preprocess.preprocess(expr: Basic, cst_args: set[Symbol], shapes: set[frozenset[Symbol]], safe: set[Symbol]) -> tuple[list[tuple[Symbol, Basic]], dict[Symbol, set[Symbol]], list[tuple[Symbol, Basic]]]

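Decompose and analyse the expression for the printer, splitting it into a cached part (the cst tree), which depends only on cst_args, and a main part (the dyn tree), which is evaluated on every call.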
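The split between a cached part and a per-call part can be pictured with the minimal sketch below. It assumes a toy representation where each step is a (target, operation, inputs) triple of plain strings standing in for sympy symbols; split_steps is a hypothetical helper, not the actual implementation of preprocess. A step is treated as constant when all its inputs are constant arguments or results of earlier constant steps.

```python
Step = tuple[str, str, tuple[str, ...]]  # (target, operation, input names)


def split_steps(steps: list[Step], cst_args: set[str]) -> tuple[list[Step], list[Step]]:
    """Partition steps into a cacheable part and a per-call part (toy model)."""
    constant = set(cst_args)  # names whose value only depends on cst_args
    cst_steps: list[Step] = []
    dyn_steps: list[Step] = []
    for step in steps:
        target, _, inputs = step
        # Literal steps (no inputs) are kept in the per-call part here.
        if inputs and all(name in constant for name in inputs):
            constant.add(target)
            cst_steps.append(step)
        else:
            constant.discard(target)  # a reassigned name is no longer constant
            dyn_steps.append(step)
    return cst_steps, dyn_steps


# Mirrors c**-2/x and sin(x) with c constant and x dynamic.
steps = [
    ("a", "pow", ("c",)),      # a = c**(-2)
    ("b", "inv", ("x",)),      # b = 1/x
    ("d", "mul", ("a", "b")),  # d = a*b
    ("e", "sin", ("x",)),      # e = sin(x)
    ("z", "zero", ()),         # z = 0
]
cst, dyn = split_steps(steps, {"c"})
```

Only the step computing c**(-2) lands in the cached part, much like _cst_0 in the examples below.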

Parameters

expr : sympy.core.basic.Basic

The complete sympy expression to compile.

cst_args : set[sympy.core.symbol.Symbol], optional

Arguments that change infrequently enough for the computations depending only on them to be cached.

shapes : set[frozenset[sympy.core.symbol.Symbol]], optional

If some parameters share the same shape, this information can be provided so that a better solution limiting allocations can be found. This variable is the set of all subsets of tensors that share the same shape. For example, {frozenset({a, b, c}), frozenset({x, y})} means that a, b and c have the same shape, and so do x and y.

safe : set[sympy.core.symbol.Symbol]

A subset of arguments that must never be modified in place or returned without a copy. Arguments not in this set may be reused internally to save memory.
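Why shape information helps limit allocations can be illustrated with a minimal sketch, assuming a simple liveness-based buffer reuse scheme (not necessarily what preprocess does internally): once a value is read for the last time, its buffer becomes free, and a freed buffer is only handed to a new value with the same shape key. The key could be, for instance, the frozenset from shapes containing the corresponding symbol. assign_buffers and its step format are hypothetical, and values are plain strings.

```python
from collections import defaultdict


def assign_buffers(
    steps: list[tuple[str, tuple[str, ...]]], shape_of: dict[str, str]
) -> dict[str, str]:
    """Map each computed value to a buffer name, reusing dead buffers.

    steps: (value, inputs) pairs in execution order, each value defined once.
    shape_of: shape-group key of every computed value.
    """
    # Index of the last step reading each name.
    last_use: dict[str, int] = {}
    for i, (_, inputs) in enumerate(steps):
        for name in inputs:
            last_use[name] = i

    free = defaultdict(list)  # shape key -> buffers available for reuse
    buffer_of: dict[str, str] = {}
    n_buffers = 0
    for i, (value, inputs) in enumerate(steps):
        # Release computed inputs read for the last time, so the result can be
        # written in place (like _0 = sin(_0) in the examples below).
        for name in dict.fromkeys(inputs):  # deduplicate, keep order
            if name in buffer_of and last_use[name] == i:
                free[shape_of[name]].append(buffer_of[name])
        key = shape_of[value]
        if free[key]:
            buffer_of[value] = free[key].pop()  # reuse a same-shape buffer
        else:
            buffer_of[value] = f"_{n_buffers}"  # allocate a new buffer
            n_buffers += 1
    return buffer_of


chain = [("t1", ("x",)), ("t2", ("t1",)), ("t3", ("t2",)), ("t4", ("x",))]
same = assign_buffers(chain, {"t1": "g", "t2": "g", "t3": "g", "t4": "g"})
grouped = assign_buffers(chain, {"t1": "a", "t2": "b", "t3": "b", "t4": "a"})
```

With a single shape group, t1, t2 and t3 all share _0; with two groups, t2 cannot take over the buffer of t1 and a second allocation is needed.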

Returns

tree : dict

A dictionary with the following keys.

cst_args : set[sympy.core.symbol.Symbol]

All the input arguments required by the cst tree.

cst_alloc : dict[sympy.core.symbol.Symbol, set[sympy.core.symbol.Symbol]]

The intermediate variables to declare for the cst tree, each mapped to the set of symbols that determine its dimensions.

cst_tree : list[tuple[sympy.core.symbol.Symbol, typing.Union[sympy.core.basic.Basic, None]]]

The successive steps of the cached function, each an (assigned variable, expression) pair.

dyn_args : set[sympy.core.symbol.Symbol]

All the input arguments required by the dyn tree.

dyn_alloc : dict[sympy.core.symbol.Symbol, set[sympy.core.symbol.Symbol]]

The intermediate variables to declare for the dyn tree, each mapped to the set of symbols that determine its dimensions.

dyn_tree : list[tuple[sympy.core.symbol.Symbol, typing.Union[sympy.core.basic.Basic, None]]]

The successive steps of the main function, each an (assigned variable, expression) pair.
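How the two trees could be consumed is shown in the minimal sketch below. It assumes that the printer turns cst_tree and dyn_tree into two Python callables (written by hand here) and that the constant values are hashable, such as floats: the cached part is only recomputed when the constant arguments change. make_evaluator, cst_fn and dyn_fn are hypothetical names, not part of cutcutcodec.

```python
from typing import Callable


def make_evaluator(cst_fn: Callable[..., tuple], dyn_fn: Callable[..., object]):
    """Combine a cached part and a per-call part into one function.

    cst_fn(*cst_values) returns the cached intermediates (like _cst_0, ...);
    dyn_fn(cached, *dyn_values) returns the final result.
    """
    cache: dict[tuple, tuple] = {}

    def evaluate(cst_values: tuple, dyn_values: tuple):
        if cst_values not in cache:  # constants changed: recompute cst part
            cache.clear()  # only keep the most recent constants
            cache[cst_values] = cst_fn(*cst_values)
        return dyn_fn(cache[cst_values], *dyn_values)

    return evaluate


calls = []  # records each evaluation of the cached part


def cst_fn(c):
    calls.append(c)
    return (c ** -2,)  # plays the role of _cst_0 = c**(-2)


def dyn_fn(cached, x):
    (cst_0,) = cached
    return cst_0 * (1 / x)  # plays the role of _cst_0 * (1/x)


f = make_evaluator(cst_fn, dyn_fn)
r1 = f((2.0,), (1.0,))
r2 = f((2.0,), (0.5,))  # same constant: cst_fn is not called again
r3 = f((4.0,), (1.0,))  # new constant: cst_fn is called once more
```

Keeping a single cache entry suits arguments that change infrequently; a small LRU cache would be an alternative if a few sets of constants alternate.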

Examples

>>> from pprint import pprint
>>> from sympy.abc import c, x
>>> from sympy import Tuple, sin
>>> from cutcutcodec.core.compilation.sympy_to_torch.preprocess import preprocess
>>> exp = Tuple(0, c, c, x, x, c**-2/x, c**-2/x, sin(sin(x)) + 1)
>>> def print_tree(tree):
...     print(
...         "cst_alloc:",
...         {
...             s: sorted(tree["cst_alloc"][s], key=str)
...             for s in sorted(tree["cst_alloc"], key=str)
...         }
...     )
...     print("cst_tree:")
...     pprint(tree["cst_tree"])
...     print(
...         "dyn_alloc:",
...         {
...             s: sorted(tree["dyn_alloc"][s], key=str)
...             for s in sorted(tree["dyn_alloc"], key=str)
...         }
...     )
...     print("dyn_tree:")
...     pprint(tree["dyn_tree"])
...
>>> tree = preprocess(exp, {c}, set(), {c, x})
>>> print_tree(tree)
cst_alloc: {_cst_0: [c], _cst_1: [c], _cst_2: [c]}
cst_tree:
[(_cst_0, c**(-2)), (_cst_1, c), (_cst_2, c), (_, (_cst_0, _cst_1, _cst_2))]
dyn_alloc: {_0: [x], _1: [c, x], _2: [], _3: [x], _4: [x], _5: [c, x]}
dyn_tree:
[(_0, 1/x),
 (_1, _0*_cst_0),
 (_0, sin(x)),
 (_0, sin(_0)),
 (_0, _0 + 1),
 (_2, 0),
 (_3, x),
 (_4, x),
 (_5, _1),
 (_, (_2, _cst_1, _cst_2, _3, _4, _1, _5, _0))]
>>> tree = preprocess(exp, set(), {frozenset({c, x})}, set())
>>> print_tree(tree)
cst_alloc: {c: [c], x: [c]}
cst_tree:
[(_, ())]
dyn_alloc: {_0: [c], _1: [c], _2: [], _3: [c], _4: [c], _5: [c]}
dyn_tree:
[(_0, c**(-2)),
 (_1, 1/x),
 (_0, _0*_1),
 (_1, sin(x)),
 (_1, sin(_1)),
 (_1, _1 + 1),
 (_2, 0),
 (_3, c),
 (_4, x),
 (_5, _0),
 (_, (_2, _3, c, _4, x, _0, _5, _1))]
>>>
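The two calls above show the effect of safe. In the first one, where safe={c, x}, each occurrence of c and x in the returned tuple goes through a dedicated variable (_cst_1, _cst_2, _3, _4), while the second call, with an empty safe, places c and x themselves in the output. The contract can be sketched on plain Python lists, assuming that an operation may write into any input buffer that is not protected; negate is a hypothetical helper, not part of cutcutcodec.

```python
def negate(name: str, data: list[float], safe: set[str]) -> list[float]:
    """Return the element-wise negation of data.

    When name is in safe, the input is copied first; otherwise its buffer
    is overwritten and returned, saving one allocation.
    """
    out = list(data) if name in safe else data  # copy only when required
    for i, value in enumerate(out):
        out[i] = -value  # in place on out, which may alias data
    return out


x = [1.0, 2.0]
protected = negate("x", x, safe={"x"})  # x is left untouched
y = [1.0, 2.0]
reused = negate("y", y, safe=set())  # y is overwritten and returned
```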