# Source code of cutcutcodec.core.compilation.sympy_to_torch.printer

#!/usr/bin/env python3

"""Static C compilation of atomic sympy expression.

This makes it possible to write C code, compile it with gcc,
and import the file as a python function.

It is slower to initialise than the dynamic version but it is faster to evaluate.
This static evaluation does not support broadcasting.

Implemented functions:

    * sympy.Abs
    * sympy.Add `+`
    * sympy.atan
    * sympy.cbrt
    * sympy.cos
    * sympy.Eq
    * sympy.exp
    * sympy.GreaterThan
    * sympy.LessThan
    * sympy.log
    * sympy.Max
    * sympy.Min
    * sympy.Mul `*`
    * sympy.Piecewise
    * sympy.Pow `/` and `**`
    * sympy.sin
    * sympy.sqrt
    * sympy.StrictGreaterThan
    * sympy.StrictLessThan

Not implemented functions:

    * sympy.acos
    * sympy.acosh
    * sympy.Add `+` and `-`
    * sympy.And
    * sympy.arg
    * sympy.asin
    * sympy.asinh
    * sympy.atan2
    * sympy.atanh
    * sympy.ceiling
    * sympy.cosh
    * sympy.Determinant
    * sympy.erf
    * sympy.floor
    * sympy.HadamardProduct
    * sympy.im
    * sympy.ITE
    * sympy.loggamma
    * sympy.MatAdd
    * sympy.Mod `%`
    * sympy.Ne
    * sympy.Not
    * sympy.Or
    * sympy.re
    * sympy.sign
    * sympy.sinh
    * sympy.tan
    * sympy.tanh
    * sympy.Trace
    * sympy.Tuple
"""

import collections
import logging

from sympy.core.basic import Atom
from sympy.core.basic import Basic
from sympy.core.containers import Tuple
from sympy.core.symbol import Symbol

from cutcutcodec.core.exceptions import CompilationError
from .printer_atom import *  # noqa: W0401


# C scalar types for which a kernel is generated: the three real types, then their complex versions.
C_TYPES = (
    "float", "double", "long double",
    "float complex", "double complex", "long double complex"
)
# Names of the numpy C-API type enums, index-aligned with C_TYPES
# (the parser looks them up with ``NP_TYPES[C_TYPES.index(c_type)]``).
NP_TYPES = (
    "NPY_FLOAT", "NPY_DOUBLE", "NPY_LONGDOUBLE", "NPY_CFLOAT", "NPY_CDOUBLE", "NPY_CLONGDOUBLE"
)


def _printer(
    tree: list[tuple[Symbol, None | Basic]],
    alloc: dict[Symbol, set[Symbol]],
    args: set[Symbol],
) -> str:
    """Return the complete C source code of the module with the tree function numpy compatible.

    One kernel is generated per C type of ``C_TYPES``. The types for which the kernel
    generation fails are skipped, as long as at least one of them succeeds.

    Parameters
    ----------
    tree : list[tuple[sympy.core.symbol.Symbol, None | sympy.core.basic.Basic]]
        Each steps. If the expression of the last step is a ``sympy.Tuple``,
        its items are the returned symbols, otherwise the symbol of the last step is returned.
    alloc : dict[sympy.core.symbol.Symbol, set[sympy.core.symbol.Symbol]]
        The intermediate variables to be declared and their respective dimensions.
        Only the keys are used here.
    args : set[sympy.core.symbol.Symbol]
        All the inputs arguments required for this function.

    Returns
    -------
    code : str
        The complete C source code of the compilable python module.
        The entry point is the ``lambdify`` function.

    Raises
    ------
    cutcutcodec.core.exceptions.CompilationError
        If the kernel generation failed for every C type.
    """
    out_is_tuple = isinstance(tree[-1][1], Tuple)
    out_symbs = list(tree[-1][1]) if out_is_tuple else [tree[-1][0]]
    kernel_args = args | set(out_symbs)

    # frozen C / python environment
    header = (
        # must be defined before Python.h is included
        "#define PY_SSIZE_T_CLEAN\n"
        # for detection and management of nested threads
        "#include <unistd.h>\n"
        "#include <sys/syscall.h>\n"
        "#ifndef SYS_gettid\n"
        "#error 'SYS_gettid unavailable on this system'\n"
        "#endif\n"
        "#define gettid() ((pid_t)syscall(SYS_gettid))\n"
        # classical imports
        "#include <complex.h>\n"
        "#include <math.h>\n"
        "#include <numpy/arrayobject.h>\n"
        "#include <omp.h>\n"
        "#include <Python.h>\n"
        "#include <stdio.h>\n"
        "#undef I"
    )
    python_context = (  # python module integration of the function
        "static PyMethodDef lambdifyMethods[] = {\n"
        '  {"lambdify", py_lambdify, METH_VARARGS, "Function for sympy expr evaluation."},\n'
        "  {NULL, NULL, 0, NULL}\n"
        "};\n"
        "static struct PyModuleDef lambdify = {\n"
        "  PyModuleDef_HEAD_INIT,\n"
        '  "lambdify",\n'
        '  "Autogenerated sympy lambdify module.",\n'
        "  -1,\n"
        "  lambdifyMethods\n"
        "};\n"
        "PyMODINIT_FUNC PyInit_lambdify(void)\n"
        "{\n"
        "  import_array();\n"
        "  return PyModule_Create(&lambdify);\n"
        "};"
    )

    # kernel, define the ``lambdify_<c_type>`` functions
    kernels: dict[str, set[str]] = {}  # source code of each function and the context functions
    errors = []
    for c_type in C_TYPES:
        try:
            kernels[c_type] = _printer_kernel(
                tree=tree,
                alloc=set(alloc)-set(out_symbs),  # the outputs are arguments, not local variables
                args=kernel_args,
                c_type=c_type,
            )
        except CompilationError as err:
            errors.append(err)
    if not kernels:
        raise CompilationError(
            "failed to compile the expression for all c types",
            *(arg for err in errors for arg in err.args)
        )

    # kernel parser
    parser = _printer_parser(
        kernel_args=sorted(map(str, kernel_args)),
        out_is_tuple=out_is_tuple,
        args=args,
        out_symbs=out_symbs,
        valid_c_types=sorted(kernels),
    )

    # final assembly, a snippet shared by several c types must be written only once,
    # otherwise the same C definition would be repeated in the module
    functions = sorted({func for funcs in kernels.values() for func in funcs})
    return (
        header + "\n\n"
        + "\n\n".join(functions) + "\n\n"
        + parser + "\n\n"
        + python_context
    )


def _printer_kernel(
    tree: list[tuple[Symbol, None | Basic]],
    alloc: set[Symbol],
    args: list[Symbol],
    c_type: str,
) -> set[str]:
    """Return the source code of the C function's kernel.

    Parameters
    ----------
    tree : list[tuple[sympy.core.symbol.Symbol, None | sympy.core.basic.Basic]]
        Each steps. If the expression of the last step is a ``sympy.Tuple``,
        this last step is not printed.
    alloc : set[sympy.core.symbol.Symbol]
        The intermediate variables to be declared. This collection is not modified.
    args : set[sympy.core.symbol.Symbol]
        All the arguments of the kernel function, they are accessed by ``[_i]``.
    c_type : str
        The C type of the numbers, for example ``float`` or ``double complex``.

    Returns
    -------
    set[str]
        The C source code of the kernel function, no python context,
        and the source code of the context required by the printed expressions.

    Examples
    --------
    >>> from pprint import pprint
    >>> from sympy.abc import c, x
    >>> from sympy import Number, Tuple, sin, symbols
    >>> import numpy as np
    >>> from cutcutcodec.core.compilation.sympy_to_torch.printer import _printer_kernel
    >>> _, _0, _1, _2, _3, _4, _5 = symbols("_ _:6")
    >>> tree = [(_0, c**(-2)), (_1, 1/x), (_2, _0*_1), (_3, Number(0)), (_1, sin(x)), (_1, sin(_1)),
    ...         (_1, _1 + 1), (_0, c), (_4, x), (_5, _2), (_, Tuple(_3, _0, c, _4, x, _2, _5, _1))]
    >>> alloc = {_0: {c}, _1: {c}, _2: {c}, _3: {c}, _4: {c}, _5: {c}}
    >>> kernel = _printer_kernel(tree, alloc, [_2, _3, c, _4, x, _0, _5, _1], "float")
    >>> for k in sorted(kernel):  # doctest: +ELLIPSIS
    ...     print(k)
    ...
    void lambdify_float(const npy_intp _dim, float *_0, float *_1, ..., float *x) {
      npy_intp _i;
      #pragma omp parallel for simd schedule(static)
      for ( _i = 0; _i < _dim; ++_i ) {
        float _0, _1, _2, _3, _4, _5;
        _0[_i] = 1.0f / c[_i] * c[_i];
        _1[_i] = 1.0f / x[_i];
        _2[_i] = _0[_i] * _1[_i];
        _3[_i] = 0.0f;
        _1[_i] = sinf(x[_i]);
        _1[_i] = sinf(_1[_i]);
        _1[_i] += 1.0f;
        _0[_i] = c[_i];
        _4[_i] = x[_i];
        _5[_i] = _2[_i];
      }
    }
    >>>
    """
    context = set()
    # work on a private copy: the variables requested by the printers are added below
    # and the collection given by the caller must stay untouched
    alloc = set(alloc)

    # signature
    signature = [
        "const npy_intp _dim",
        *(f"{c_type} *{a}" for a in sorted(map(str, args))),
    ]
    code_lines = [
        f"void lambdify_{c_type.replace(' ', '_')}({', '.join(signature)})" + " {",
        # global declaration
        "  npy_intp _i;",
        # main for declaration
        "  #pragma omp parallel for simd schedule(static)",
        "  for ( _i = 0; _i < _dim; ++_i ) {",
    ]

    # expression on one item, the arguments are arrays, everything else is a local scalar
    kernel_lines = []
    indexing = collections.defaultdict(lambda: "", {a: "[_i]" for a in args})
    # a final Tuple step only describes the returned values, there is nothing to compute for it
    steps = tree[:-1] if isinstance(tree[-1][1], Tuple) else tree
    for out, expr in steps:
        new_context, new_alloc, new_lines = _print_atomic(expr, out, indexing, c_type)
        context |= new_context
        alloc |= new_alloc
        kernel_lines.extend(f"    {line}" for line in new_lines)

    # the local declarations are only known once all the steps have been printed
    if alloc:
        code_lines.append(f"    {c_type} {', '.join(sorted(map(str, alloc)))};")
    code_lines.extend(kernel_lines)
    code_lines.append("  }")  # close for

    # close main function
    code_lines.append("}")
    return context | {"\n".join(code_lines)}


def _printer_parser(
    kernel_args: list[str],
    out_is_tuple: bool,
    args: set[Symbol],
    out_symbs: list[Symbol],
    valid_c_types: list[str],
) -> str:
    """Write the code of the C python parser.

    Help to ``cutcutcodec.core.sympy_to_torch.printer._printer``.

    The generated ``py_lambdify`` function parses the input arrays, creates the missing
    output arrays, calls the kernel matching the type of the reference array
    (the input whose name is the smallest) and returns the output arrays.

    Parameters
    ----------
    kernel_args : list[str]
        The ordered list of the input symbol name of the kernels functions.
    out_is_tuple : boolean
        True if the returned value has to be packed inside a tuple.
    args : set[sympy.core.symbol.Symbol]
        All the inputs arguments required for this function. It must not be empty.
    out_symbs : list[sympy.core.symbol.Symbol]
        The ordered list of the symbols to return. The same symbol can appear several times.
    valid_c_types : list[str]
        The c type list where the kernel compilation is successful

    Returns
    -------
    str
        The C source code of the ``py_lambdify`` function.
    """
    ref_symb = min(args, key=str)
    kernel_names = sorted(map(str, kernel_args))
    # Each output array holds exactly one owned reference (created or increfed once),
    # so it has to be released once, even if the symbol is returned several times.
    release_outs = [f"    Py_DECREF({out_});" for out_ in dict.fromkeys(out_symbs)]

    # function api and variables declaration
    lines = [
        "static PyObject *py_lambdify(PyObject *self, PyObject *args) {",
        "  PyArrayObject " + ", ".join(f"*{a}" for a in kernel_names) + ";",
        "  npy_intp _dim;",
    ]
    if out_is_tuple:
        lines.append("  PyObject *_out;")

    # parse the input arguments
    parse_args = [
        "args",
        ('"' + "O!"*len(args) + '"'),
        *(f"&PyArray_Type, &{a}" for a in sorted(map(str, args)))
    ]
    lines += [
        f"  if ( !PyArg_ParseTuple({', '.join(parse_args)}) )" + "{",
        "    return NULL;",
        "  }",
        f"  _dim = PyArray_DIM({ref_symb}, 0);",
    ]

    # allocation of the new tensors, on failure the arrays already created are released
    new_names = sorted(map(str, set(out_symbs)-args))
    for i, new in enumerate(new_names):
        lines += [
            (
                f"  {new} = (PyArrayObject *)"
                f"PyArray_SimpleNew(1, &_dim, PyArray_TYPE({ref_symb}));"
            ),
            f"  if ( NULL == {new} ) " + "{",
            f'    PyErr_SetString(PyExc_RuntimeError, "failed to create a new {new} array");',
            *(f"    Py_DECREF({prev});" for prev in new_names[:i]),
            "    return NULL;",
            "  }",
        ]

    # increment ref of out tensors that are also inputs (borrowed references)
    lines.extend(
        f"  Py_INCREF((PyObject *){out_});" for out_ in sorted(set(out_symbs) & args, key=str)
    )

    # thread management, all the cores from the main thread only
    lines += [
        "  if (gettid() == getpid()) {",
        "    omp_set_num_threads(omp_get_num_procs());",
        "  }",
        "  else {",
        "    omp_set_num_threads(1);",
        "  }",
    ]

    # call the C func
    lines.append(f"  switch(PyArray_TYPE({ref_symb})) " + "{")
    for c_type in valid_c_types:
        call_args = [
            "_dim",
            *(f"({c_type} *)PyArray_DATA({a})" for a in kernel_names),
        ]
        lines += [
            f"  case {NP_TYPES[C_TYPES.index(c_type)]}:",
            "    Py_BEGIN_ALLOW_THREADS",  # unlock the GIL, allow thread
            f"    lambdify_{c_type.replace(' ', '_')}({', '.join(call_args)});",
            "    Py_END_ALLOW_THREADS",
        ]
        if out_is_tuple:
            pack_args = [
                str(len(out_symbs)),
                *(f"(PyObject *){o}" for o in out_symbs),
            ]
            # PyTuple_Pack takes its own references, ours are given back just after
            lines.append(f"    _out = PyTuple_Pack({', '.join(pack_args)});")
            lines.extend(release_outs)
            lines.append("    return _out;")
        else:
            lines.append(f"    return (PyObject *){out_symbs[0]};")
    lines += [
        "  default:",
        '    PyErr_SetString(PyExc_TypeError, "the array type is not supported");',
        *release_outs,
        "    return NULL;",
        "  }",
        "}",
    ]

    return "\n".join(lines)


def _print_atomic(
    expr: Basic, out: Symbol, indexing: collections.defaultdict[Symbol, str], c_type: str
) -> tuple[set[str], set[Symbol], list[str]]:
    """Translate one atomic sympy step ``out = expr`` into C source code lines.

    Three cases are handled, in this order:

    * ``expr`` is a sympy atom: a simple assignment is written.
    * A function ``c_<lowercase class name of expr>`` exists in this module:
      the work is delegated to it, called with ``(out, indexing, c_type, *expr.args)``.
    * Otherwise, a warning is logged and ``sympy.ccode`` is used as an experimental fallback.

    Parameters
    ----------
    expr : sympy.core.basic.Basic
        The sympy atomic expression to eval.
    out : sympy.core.symbol.Symbol
        The variable to assign, it is assumed to be already declared.
    indexing : collections.defaultdict[sympy.core.symbol.Symbol, str]
        How to access the content of each variable at the current position.
        An empty string means that the variable is directly accessible, without indexing.
        Otherwise, the value is the text appended to the variable name, brackets included.
        For example, "[_i]" means ``var[_i]`` and "[_j - 2]" means ``var[_j - 2]``.
    c_type : str
        The C data type of the variables.

    Returns
    -------
    context : set[str]
        The functions required by the returned code lines.
    alloc : set[sympy.core.symbol.Symbol]
        The variables which need to be declared before the returned code lines.
    code_lines : list[str]
        The C code lines corresponding to the given expression.

    Notes
    -----
    No exception is raised here when no dedicated printer exists.
    NOTE(review): the dedicated printers presumably raise
    ``cutcutcodec.core.exceptions.CompilationError`` by themselves, to be confirmed.
    """
    # simplest case, direct copy of a symbol or a number
    if expr.is_Atom:
        lhs = atom2str(out, indexing, c_type)
        rhs = atom2str(expr, indexing, c_type)
        return set(), set(), [f"{lhs} = {rhs};"]

    # dedicated printer, found by name in the module namespace
    namespace = globals()
    printer_name = f"c_{expr.__class__.__name__.lower()}"
    if printer_name in namespace:
        return namespace[printer_name](out, indexing, c_type, *expr.args)

    # generic fallback, used in place of a hard compilation failure
    logging.warning("experimental use of sympy.ccode on %s", expr)
    from sympy.printing.codeprinter import ccode  # pylint: disable=C0415
    # the access suffix is merged into the name of each symbol so that ccode prints it as is
    indexed_expr = expr.xreplace(
        {symb: Symbol(f"{symb}{suffix}") for symb, suffix in indexing.items()}
    )
    target = atom2str(out, indexing, c_type)
    statement = ccode(
        indexed_expr, assign_to=target, standard="c99", contract=False, human=True
    )
    return set(), set(), [statement]


def c_piecewise(
    out: Symbol, indexing: collections.defaultdict[Symbol, str], c_type: str, *parts: Atom
) -> tuple[set[str], set[Symbol], list[str]]:
    """C ... if ... else ... operation.

    Each part is an ``(expression, condition)`` pair. The parts are printed as nested
    ``if (...) { ... } else { ... }`` blocks, the last part is written without test.
    A condition that is not an atom is first evaluated in the ``_buf`` variable.

    Parameters
    ----------
    out : sympy.core.symbol.Symbol
        The variable to assign.
    indexing : collections.defaultdict[sympy.core.symbol.Symbol, str]
        The way to access the content of the variables.
    c_type : str
        The C data type of the variables.
    *parts
        The ``(expression, condition)`` pairs of the piecewise function, at least one.

    Returns
    -------
    context : set[str]
        The functions required by the returned code lines, all the branches included.
    alloc : set[sympy.core.symbol.Symbol]
        The variables which need to be declared, all the branches included.
    code_lines : list[str]
        The C code lines.

    Examples
    --------
    >>> from sympy.abc import x
    >>> from sympy.functions.elementary.piecewise import Piecewise
    >>> import numpy as np
    >>> from cutcutcodec.core.compilation.sympy_to_torch.lambdify import _lambdify_c
    >>> from cutcutcodec.core.compilation.sympy_to_torch.printer import _print_atomic, _printer
    >>> func = _lambdify_c(_printer([(x, Piecewise((0, x < 0), (x, True)))], {}, {x}))
    >>> func(np.array([np.nan, -np.inf, -1.0, 0.0, 1.0, np.inf]))
    array([nan,  0.,  0.,  0.,  1., inf])
    >>>
    """
    # last part, the value is assigned whatever the condition is
    if len(parts) == 1:
        code = [
            f"{atom2str(out, indexing, c_type)} = "  # noqa: W0401
            f"{atom2str(parts[0][0], indexing, c_type)};"  # noqa: W0401
        ]
        return set(), set(), code

    context: set[str] = set()
    alloc: set[Symbol] = set()
    code: list[str] = []

    # test of the first part
    value, cond = parts[0][0], parts[0][1]
    if cond.is_Atom:
        test = atom2str(cond, indexing, c_type)  # noqa: W0401
    else:
        buf = Symbol("_buf")
        cond_context, cond_alloc, cond_code = _print_atomic(cond, buf, indexing, c_type)
        context |= cond_context
        alloc |= cond_alloc | {buf}
        code.extend(cond_code)
        test = atom2str(buf, indexing, c_type)  # noqa: W0401

    # The requirements of the nested branches are forwarded as well, otherwise a ``_buf``
    # needed by a following condition, or a helper function, would never be declared.
    then_context, then_alloc, then_code = _print_atomic(value, out, indexing, c_type)
    else_context, else_alloc, else_code = c_piecewise(out, indexing, c_type, *parts[1:])
    context |= then_context | else_context
    alloc |= then_alloc | else_alloc

    code.append(f"if ({test}) {{")
    code.extend(then_code)
    code.append("} else {")
    code.extend(else_code)
    code.append("}")
    return context, alloc, code