cutcutcodec.core.compilation.sympy_to_torch.preprocess.preprocess
- cutcutcodec.core.compilation.sympy_to_torch.preprocess.preprocess(expr: Basic, cst_args: set[Symbol], shapes: set[frozenset[Symbol]], safe: set[Symbol]) → tuple[list[tuple[Symbol, Basic]], dict[Symbol, set[Symbol]], list[tuple[Symbol, Basic]]]
Decompose and analyse the expression for the printer.
Parameters
- expr : sympy.core.basic.Basic
The complete sympy expression to compile.
- cst_args : set[sympy.core.symbol.Symbol], optional
Arguments that change infrequently enough to be cached.
- shapes : set[frozenset[sympy.core.symbol.Symbol]], optional
If some parameters have the same shape, this information can be provided so that a better solution, one that limits allocations, can be found. This variable represents the set of all subsets of tensors that share a shape. For example, {frozenset({a, b, c}), frozenset({x, y})} means that a, b, and c have the same shape, and that x and y have the same shape.
- safe : set[sympy.core.symbol.Symbol]
The subset of arguments that must never be modified in place or returned without a copy. Arguments not present in this set can be reused internally to optimize memory.
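The `shapes` argument described above is a set of frozensets, one per group of same-shaped tensors. As a minimal sketch, such a set can be built from a mapping of argument to shape. The helper name `group_by_shape` and the example shapes are hypothetical and not part of cutcutcodec; plain strings stand in for the sympy symbols so that the snippet has no dependencies, and `Symbol` objects, being hashable, could be used as keys in the same way.

```python
from collections import defaultdict


def group_by_shape(shape_of):
    """Group arguments that share a shape (hypothetical helper).

    ``shape_of`` maps each argument to its shape, given as a sequence of ints.
    The result has the structure documented for ``shapes``:
    a set of frozensets, one frozenset per distinct shape.
    """
    groups = defaultdict(set)
    for arg, shape in shape_of.items():
        # tuple() makes the shape hashable so it can index the groups
        groups[tuple(shape)].add(arg)
    return {frozenset(members) for members in groups.values()}


# Illustrative shapes only: a, b and c share one shape, x and y another.
shape_of = {
    "a": (1080, 1920),
    "b": (1080, 1920),
    "c": (1080, 1920),
    "x": (48000,),
    "y": (48000,),
}
shapes = group_by_shape(shape_of)
```

With these inputs, `shapes` has the same structure as the example in the parameter description: one frozenset for a, b, and c, and one for x and y.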
Returns
- tree : dict
- cst_args : set[sympy.core.symbol.Symbol]
All the input arguments required by the cst tree.
- cst_alloc : dict[sympy.core.symbol.Symbol, set[sympy.core.symbol.Symbol]]
The intermediate variables to be declared, and their respective dimensions, for the cst tree.
- cst_tree : list[tuple[sympy.core.symbol.Symbol, None | sympy.core.basic.Basic]]
Each step of the cached function.
- dyn_args : set[sympy.core.symbol.Symbol]
All the input arguments required by the dyn tree.
- dyn_alloc : dict[sympy.core.symbol.Symbol, set[sympy.core.symbol.Symbol]]
The intermediate variables to be declared, and their respective dimensions, for the dyn tree.
- dyn_tree : list[tuple[sympy.core.symbol.Symbol, None | sympy.core.basic.Basic]]
Each step of the main function.
Examples
>>> from pprint import pprint
>>> from sympy.abc import c, x
>>> from sympy import Tuple, sin
>>> from cutcutcodec.core.compilation.sympy_to_torch.preprocess import preprocess
>>> exp = Tuple(0, c, c, x, x, c**-2/x, c**-2/x, sin(sin(x)) + 1)
>>> def print_tree(tree):
...     print(
...         "cst_alloc:",
...         {
...             s: sorted(tree["cst_alloc"][s], key=str)
...             for s in sorted(tree["cst_alloc"], key=str)
...         }
...     )
...     print("cst_tree:")
...     pprint(tree["cst_tree"])
...     print(
...         "dyn_alloc:",
...         {
...             s: sorted(tree["dyn_alloc"][s], key=str)
...             for s in sorted(tree["dyn_alloc"], key=str)
...         }
...     )
...     print("dyn_tree:")
...     pprint(tree["dyn_tree"])
...
>>> tree = preprocess(exp, {c}, set(), {c, x})
>>> print_tree(tree)
cst_alloc: {_cst_0: [c], _cst_1: [c], _cst_2: [c]}
cst_tree:
[(_cst_0, c**(-2)), (_cst_1, c), (_cst_2, c), (_, (_cst_0, _cst_1, _cst_2))]
dyn_alloc: {_0: [x], _1: [c, x], _2: [], _3: [x], _4: [x], _5: [c, x]}
dyn_tree:
[(_0, 1/x),
 (_1, _0*_cst_0),
 (_0, sin(x)),
 (_0, sin(_0)),
 (_0, _0 + 1),
 (_2, 0),
 (_3, x),
 (_4, x),
 (_5, _1),
 (_, (_2, _cst_1, _cst_2, _3, _4, _1, _5, _0))]
>>> tree = preprocess(exp, set(), {frozenset({c, x})}, set())
>>> print_tree(tree)
cst_alloc: {c: [c], x: [c]}
cst_tree:
[(_, ())]
dyn_alloc: {_0: [c], _1: [c], _2: [], _3: [c], _4: [c], _5: [c]}
dyn_tree:
[(_0, c**(-2)),
 (_1, 1/x),
 (_0, _0*_1),
 (_1, sin(x)),
 (_1, sin(_1)),
 (_1, _1 + 1),
 (_2, 0),
 (_3, c),
 (_4, x),
 (_5, _0),
 (_, (_2, _3, c, _4, x, _0, _5, _1))]
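To make the step lists easier to read, here is a hedged sketch that replays the `cst_tree` and `dyn_tree` printed by the first call (`cst_args={c}`) symbolically, using only sympy. The trees are rebuilt by hand from that output with plain `Symbol` objects, which is an assumption about how `preprocess` names and types its intermediates. `run_steps` is a hypothetical helper, not a cutcutcodec function, and it ignores the allocation information entirely.

```python
from sympy import Symbol, Tuple, sin, symbols, sympify

c, x = symbols("c x")
# Hand-built stand-ins for the intermediate symbols shown in the output.
cst0, cst1, cst2 = symbols("_cst_0 _cst_1 _cst_2")
t0, t1, t2, t3, t4, t5 = symbols("_0 _1 _2 _3 _4 _5")
out = Symbol("_")  # the final step of each tree is assigned to "_"

cst_tree = [(cst0, c**-2), (cst1, c), (cst2, c), (out, Tuple(cst0, cst1, cst2))]
dyn_tree = [
    (t0, 1 / x), (t1, t0 * cst0), (t0, sin(x)), (t0, sin(t0)), (t0, t0 + 1),
    (t2, 0), (t3, x), (t4, x), (t5, t1),
    (out, Tuple(t2, cst1, cst2, t3, t4, t1, t5, t0)),
]


def run_steps(steps, env=None):
    """Replay (target, expression) steps in order (hypothetical helper).

    Each expression is rewritten with the values currently bound in ``env``,
    then bound to its target. A target may be rebound by a later step,
    which is how ``_0`` is reused several times in ``dyn_tree``.
    """
    env = dict(env or {})
    for target, expr in steps:
        env[target] = sympify(expr).subs(env, simultaneous=True)
    return env


cached = run_steps(cst_tree)          # the part that depends only on c
result = run_steps(dyn_tree, cached)[out]
```

Under these assumptions, `result` should be structurally equal to the original `exp`. This also shows why the order of the steps matters: `_1` has to be computed from `_0` before `_0` is overwritten with `sin(x)`.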