#!/usr/bin/env python3
"""Static C compilation of atomic sympy expression.
This makes it possible to write C code, compile it with gcc,
and import the compiled module as a python function.
It is slower to initialise than the dynamic version, but faster to evaluate.
This static evaluation does not support broadcasting.
Implemented functions:
* sympy.Abs
* sympy.Add `+`
* sympy.atan
* sympy.cbrt
* sympy.cos
* sympy.Eq
* sympy.exp
* sympy.GreaterThan
* sympy.LessThan
* sympy.log
* sympy.Max
* sympy.Min
* sympy.Mul `*`
* sympy.Or
* sympy.Piecewise
* sympy.Pow `/` and `**`
* sympy.sin
* sympy.sqrt
* sympy.StrictGreaterThan
* sympy.StrictLessThan
Not implemented functions:
* sympy.acos
* sympy.acosh
* sympy.Add `+` and `-`
* sympy.And
* sympy.arg
* sympy.asin
* sympy.asinh
* sympy.atan2
* sympy.atanh
* sympy.ceiling
* sympy.cosh
* sympy.Determinant
* sympy.erf
* sympy.floor
* sympy.HadamardProduct
* sympy.im
* sympy.ITE
* sympy.loggamma
* sympy.MatAdd
* sympy.Mod `%`
* sympy.Ne
* sympy.Not
* sympy.re
* sympy.sign
* sympy.sinh
* sympy.tan
* sympy.tanh
* sympy.Trace
* sympy.Tuple
"""
import collections
from sympy.core.basic import Atom
from sympy.core.basic import Basic
from sympy.core.containers import Tuple
from sympy.core.symbol import Symbol
from cutcutcodec.core.exceptions import CompilationError
from .printer_atom import * # noqa: W0401
# Supported C scalar types, each paired with the numpy type number of the same memory layout:
# ``NP_TYPES[i]`` is the numpy counterpart of ``C_TYPES[i]`` (both are plain tuples).
C_TYPES, NP_TYPES = zip(
    ("float", "NPY_FLOAT"),
    ("double", "NPY_DOUBLE"),
    ("long double", "NPY_LONGDOUBLE"),
    ("float complex", "NPY_CFLOAT"),
    ("double complex", "NPY_CDOUBLE"),
    ("long double complex", "NPY_CLONGDOUBLE"),
)
def _printer(
    tree: list[tuple[Symbol, None | Basic]],
    alloc: dict[Symbol, set[Symbol]],
    args: set[Symbol],
) -> str:
    """Return the complete C source code of the module with the tree function numpy compatible.

    Parameters
    ----------
    tree : list[tuple[sympy.core.symbol.Symbol, None | sympy.core.basic.Basic]]
        Each step. If the last expression is a ``sympy.Tuple``, its elements are the
        returned symbols, otherwise the symbol assigned by the last step is returned.
    alloc : dict[sympy.core.symbol.Symbol, set[sympy.core.symbol.Symbol]]
        The intermediate variables to be declared and their respective dimensions.
    args : set[sympy.core.symbol.Symbol]
        All the input arguments required for this function.

    Returns
    -------
    code : str
        The complete C source code of the compilable python module.
        The entry point is the ``lambdify`` function.

    Raises
    ------
    cutcutcodec.core.exceptions.CompilationError
        If the kernel can not be generated for any of the C types of ``C_TYPES``.
    """
    out_is_tuple = isinstance(tree[-1][1], Tuple)
    out_symbs = list(tree[-1][1]) if out_is_tuple else [tree[-1][0]]
    kernel_args = args | set(out_symbs)

    # frozen C / python environment
    header = (
        # PY_SSIZE_T_CLEAN must be defined before Python.h, and Python.h must be included
        # before any standard header (python C-API "Extending" documentation)
        "#define PY_SSIZE_T_CLEAN\n"
        "#include <Python.h>\n"
        "#include <numpy/arrayobject.h>\n"
        # classical imports
        "#include <complex.h>\n"
        "#include <math.h>\n"
        "#include <omp.h>\n"
        "#include <stdio.h>\n"
        # for detection and management of nested threads
        "#include <unistd.h>\n"
        "#include <sys/syscall.h>\n"
        "#ifndef SYS_gettid\n"
        "#error 'SYS_gettid unavailable on this system'\n"
        "#endif\n"
        "#define gettid() ((pid_t)syscall(SYS_gettid))\n"
        # complex.h defines the macro I, it is undefined here
        # (presumably so that no generated identifier is clobbered by it)
        "#undef I"
    )
    python_context = (  # python module integration of the function
        "static PyMethodDef lambdifyMethods[] = {\n"
        ' {"lambdify", py_lambdify, METH_VARARGS, "Function for sympy expr evaluation."},\n'
        " {NULL, NULL, 0, NULL}\n"
        "};\n"
        "static struct PyModuleDef lambdify = {\n"
        " PyModuleDef_HEAD_INIT,\n"
        ' "lambdify",\n'
        ' "Autogenerated sympy lambdify module.",\n'
        " -1,\n"
        " lambdifyMethods\n"
        "};\n"
        "PyMODINIT_FUNC PyInit_lambdify(void)\n"
        "{\n"
        " import_array();\n"
        " return PyModule_Create(&lambdify);\n"
        "}"  # no ';' after a function body, ISO C forbids it at file scope
    )

    # kernel, define the ``lambdify_<c_type>`` functions, one per compilable C type
    kernels: dict[str, set[str]] = {}  # kernel source code and its context functions
    errors = []
    for c_type in C_TYPES:
        try:
            kernels[c_type] = _printer_kernel(
                tree=tree,
                alloc=set(alloc) - set(out_symbs),
                args=kernel_args,
                c_type=c_type,
            )
        except CompilationError as err:
            errors.append(err)
    if not kernels:
        raise CompilationError(
            "failed to compile the expression for all c types",
            *(arg for err in errors for arg in err.args)
        )

    # kernel parser
    parser = _printer_parser(
        kernel_args=sorted(map(str, kernel_args)),
        out_is_tuple=out_is_tuple,
        args=args,
        out_symbs=out_symbs,
        valid_c_types=sorted(kernels),
    )

    # final assembly, a context function shared by several kernels is written only once
    # (writing it twice would be a C redefinition error)
    functions = sorted(set().union(*kernels.values()))
    return (
        header + "\n\n"
        + "\n\n".join(functions) + "\n\n"
        + parser + "\n\n"
        + python_context
    )
def _printer_kernel(
    tree: list[tuple[Symbol, None | Basic]],
    alloc: set[Symbol],
    args: set[Symbol] | list[Symbol],
    c_type: str,
) -> set[str]:
    """Return the source code of the C function's kernel.

    Parameters
    ----------
    tree : list[tuple[sympy.core.symbol.Symbol, None | sympy.core.basic.Basic]]
        Each step. If the last expression is a ``sympy.Tuple``, this last step is skipped
        because it only lists the returned symbols.
    alloc : set[sympy.core.symbol.Symbol]
        The intermediate variables to be declared. This container is not modified.
    args : set[sympy.core.symbol.Symbol]
        All the array arguments of this function (inputs and outputs).
    c_type : str
        The C type of the numbers, one of ``C_TYPES``.

    Returns
    -------
    set[str]
        The C source code of the kernel function ``lambdify_<c_type>``
        and of the context functions it requires, no python context.

    Examples
    --------
    >>> from pprint import pprint
    >>> from sympy.abc import c, x
    >>> from sympy import Number, Tuple, sin, symbols
    >>> import numpy as np
    >>> from cutcutcodec.core.compilation.sympy_to_torch.printer import _printer_kernel
    >>> _, _0, _1, _2, _3, _4, _5 = symbols("_ _:6")
    >>> tree = [(_0, c**(-2)), (_1, 1/x), (_2, _0*_1), (_3, Number(0)), (_1, sin(x)), (_1, sin(_1)),
    ... (_1, _1 + 1), (_0, c), (_4, x), (_5, _2), (_, Tuple(_3, _0, c, _4, x, _2, _5, _1))]
    >>> alloc = {_0: {c}, _1: {c}, _2: {c}, _3: {c}, _4: {c}, _5: {c}}
    >>> kernel = _printer_kernel(tree, alloc, [_2, _3, c, _4, x, _0, _5, _1], "float")
    >>> for k in sorted(kernel): # doctest: +ELLIPSIS
    ...     print(k)
    ...
    void lambdify_float(const npy_intp _dim, float *_0, float *_1, ..., float *x) {
     npy_intp _i;
     #pragma omp parallel for simd schedule(static)
     for ( _i = 0; _i < _dim; ++_i ) {
     float _0, _1, _2, _3, _4, _5;
     _0[_i] = 1.0f / c[_i] * c[_i];
     _1[_i] = 1.0f / x[_i];
     _2[_i] = _0[_i] * _1[_i];
     _3[_i] = 0.0f;
     _1[_i] = sinf(x[_i]);
     _1[_i] = sinf(_1[_i]);
     _1[_i] += 1.0f;
     _0[_i] = c[_i];
     _4[_i] = x[_i];
     _5[_i] = _2[_i];
     }
    }
    >>>
    """
    # local copy: the sub-printers may add variables (ex ``_buf``) without touching the
    # caller's container (which may also be a dict, as in the example above)
    alloc = set(alloc)
    context: set[str] = set()

    # expression on one item, the arrays are indexed by the loop counter, locals are scalars
    indexing = collections.defaultdict(lambda: "", {a: "[_i]" for a in args})
    steps = tree[:-1] if isinstance(tree[-1][1], Tuple) else tree
    kernel_lines = []
    for out, expr in steps:
        new_context, new_alloc, new_lines = _print_atomic(expr, out, indexing, c_type)
        context |= new_context
        alloc |= new_alloc
        kernel_lines.extend(f" {line}" for line in new_lines)

    # signature and main loop, the locals are declared inside the loop body
    # so that each omp thread gets its own copy
    signature = ", ".join(
        ["const npy_intp _dim", *(f"{c_type} *{a}" for a in sorted(map(str, args)))]
    )
    code_lines = [
        f"void lambdify_{c_type.replace(' ', '_')}({signature})" + " {",
        " npy_intp _i;",
        " #pragma omp parallel for simd schedule(static)",
        " for ( _i = 0; _i < _dim; ++_i ) {",
    ]
    if alloc:
        code_lines.append(f" {c_type} {', '.join(sorted(map(str, alloc)))};")
    code_lines.extend(kernel_lines)
    code_lines.append(" }")  # close for
    code_lines.append("}")  # close main function
    return context | {"\n".join(code_lines)}
def _printer_parser(
    kernel_args: list[str],
    out_is_tuple: bool,
    args: set[Symbol],
    out_symbs: list[Symbol],
    valid_c_types: list[str],
) -> str:
    """Write the code of the C python parser.

    Help to ``cutcutcodec.core.sympy_to_torch.printer._printer``.
    The generated ``py_lambdify`` function:

    * parses the numpy input arrays;
    * checks that they are 1d, c contiguous, of the same length and of the same dtype,
      because the kernels read their raw buffers;
    * allocates the new output arrays;
    * calls the kernel matching the dtype and returns the outputs.

    Parameters
    ----------
    kernel_args : list[str]
        The ordered list of the input symbol name of the kernels functions.
    out_is_tuple : boolean
        True if the returned value has to be packed inside a tuple.
    args : set[sympy.core.symbol.Symbol]
        All the inputs arguments required for this function.
    out_symbs : list[sympy.core.symbol.Symbol]
        The ordered list of the symbols to return, the same symbol can appear several times.
    valid_c_types : list[str]
        The c type list where the kernel compilation is successful.

    Returns
    -------
    parser : str
        The C source code of the ``py_lambdify`` function.
    """
    ref_symb = min(args, key=str)
    in_names = sorted(map(str, args))
    # This function owns exactly one reference per distinct output:
    # INCREF for the inputs returned as is, PyArray_SimpleNew for the new arrays.
    # PyTuple_Pack takes its own reference per slot, so each distinct output
    # must be released once, even if it is returned several times.
    unique_outs = sorted(set(out_symbs), key=str)

    # function api and variables declaration
    parser = ""
    parser += "static PyObject *py_lambdify(PyObject *self, PyObject *args) {\n"
    parser += " PyArrayObject " + ", ".join(f"*{a}" for a in sorted(map(str, kernel_args))) + ";\n"
    parser += " npy_intp _dim;\n"
    if out_is_tuple:
        parser += " PyObject *_out;\n"

    # parse the input arguments (borrowed references)
    parse_args = [
        "args",
        ('"' + "O!"*len(args) + '"'),
        *(f"&PyArray_Type, &{a}" for a in in_names)
    ]
    parser += f" if ( !PyArg_ParseTuple({', '.join(parse_args)}) )" + "{\n"
    parser += " return NULL;\n"
    parser += " }\n"

    # check the inputs before any allocation, so no reference has to be released on error
    for name in in_names:
        parser += (
            f" if ( PyArray_NDIM({name}) != 1 || !PyArray_IS_C_CONTIGUOUS({name}) ) " "{\n"
        )
        parser += (
            " PyErr_SetString(PyExc_ValueError, "
            f'"the array {name} must be 1d and c contiguous");\n'
        )
        parser += " return NULL;\n"
        parser += " }\n"
    parser += f" _dim = PyArray_DIM({ref_symb}, 0);\n"
    for name in in_names:
        if name == str(ref_symb):
            continue
        parser += (
            f" if ( PyArray_TYPE({name}) != PyArray_TYPE({ref_symb})"
            f" || PyArray_DIM({name}, 0) != _dim ) " "{\n"
        )
        parser += (
            " PyErr_SetString(PyExc_ValueError, "
            f'"the array {name} must have the same dtype and length as {ref_symb}");\n'
        )
        parser += " return NULL;\n"
        parser += " }\n"

    # allocation of the new tensors, on failure release the ones already allocated
    new_names = sorted(map(str, set(out_symbs) - args))
    for i, new in enumerate(new_names):
        parser += (
            f" {new} = (PyArrayObject *)PyArray_SimpleNew(1, &_dim, PyArray_TYPE({ref_symb}));\n"
        )
        parser += f" if ( NULL == {new} ) " + "{\n"
        parser += (
            f' PyErr_SetString(PyExc_RuntimeError, "failed to create a new {new} array");\n'
        )
        for previous in new_names[:i]:
            parser += f" Py_DECREF({previous});\n"
        parser += " return NULL;\n"
        parser += " }\n"

    # increment ref of out tensors that are also inputs (they were borrowed)
    for out_ in sorted(set(out_symbs) & args, key=str):
        parser += f" Py_INCREF((PyObject *){out_});\n"

    # thread management, only the main thread of the process uses all the cores
    parser += " if (gettid() == getpid()) {\n"
    parser += " omp_set_num_threads(omp_get_num_procs());\n"
    parser += " }\n"
    parser += " else {\n"
    parser += " omp_set_num_threads(1);\n"
    parser += " }\n"

    # call the C func matching the dtype
    parser += f" switch(PyArray_TYPE({ref_symb})) " + "{\n"
    for c_type in valid_c_types:
        parser += f" case {NP_TYPES[C_TYPES.index(c_type)]}:\n"
        parser += " Py_BEGIN_ALLOW_THREADS\n"
        call_args = [
            "_dim",
            *(f"({c_type} *)PyArray_DATA({a})" for a in sorted(map(str, kernel_args))),
        ]
        parser += f" lambdify_{c_type.replace(' ', '_')}({', '.join(call_args)});\n"
        parser += " Py_END_ALLOW_THREADS\n"  # unlock the GIL, allow thread
        if out_is_tuple:
            pack_args = [
                str(len(out_symbs)),
                *(f"(PyObject *){o}" for o in out_symbs),
            ]
            parser += f" _out = PyTuple_Pack({', '.join(pack_args)});\n"
            for out_ in unique_outs:
                parser += f" Py_DECREF({out_});\n"
            parser += " return _out;\n"
        else:
            parser += f" return (PyObject *){out_symbs[0]};\n"
    parser += " default:\n"
    parser += ' PyErr_SetString(PyExc_TypeError, "the array type is not supported");\n'
    for out_ in unique_outs:
        parser += f" Py_DECREF({out_});\n"
    parser += " return NULL;\n"
    parser += " }\n"
    parser += "}"
    return parser
def _print_atomic(
    expr: Basic, out: Symbol, indexing: collections.defaultdict[Symbol, str], c_type: str
) -> tuple[set[str], set[Symbol], list[str]]:
    """Write the sympy atomic expression as a valid C source code lines.

    The dedicated printer ``c_<lowercase class name>`` of this module namespace is used
    when it exists, otherwise the generic sympy C99 printer is used as a fallback.

    Parameters
    ----------
    expr : sympy.core.basic.Basic
        The sympy atomic expression to eval.
    out : sympy.core.symbol.Symbol
        The variable set, assume that this var is already declared.
    indexing : collections.defaultdict[sympy.core.symbol.Symbol, str]
        The way to access the content of the vars, at the current dimension.
        An empty string means that the var is directly accessible with no indexing.
        Otherwise, the value corresponds to the content of the brackets.
        For example, "[_i]" means ``var[_i]`` and "[_j - 2]" means ``var[_j - 2]``.
    c_type : str
        The C data type of the variables.

    Returns
    -------
    context : set[str]
        The functions required for the returned code lines.
    alloc : set[sympy.core.symbol.Symbol]
        The variables which need to be declared before these code lines.
    code_lines : list[str]
        The C code lines corresponding to the given expression.

    Raises
    ------
    cutcutcodec.core.exceptions.CompilationError
        If the dedicated printer rejects the expression for this ``c_type``,
        or if there is no dedicated printer and ``c_type`` is complex.
    """
    if expr.is_Atom:
        return (
            set(),
            set(),
            [
                f"{atom2str(out, indexing, c_type)} "  # noqa: W0401
                f"= {atom2str(expr, indexing, c_type)};"  # noqa: W0401
            ],
        )

    # dedicated printer, ex ``c_add`` for ``sympy.Add``
    func = globals().get(f"c_{expr.__class__.__name__.lower()}")
    if func is not None:
        return func(out, indexing, c_type, *expr.args)

    # Generic fallback on the sympy C99 printer. It emits the real math.h functions
    # (sin, pow, ...), which would silently drop the imaginary part of complex values.
    # Raising here lets ``_printer`` skip only the complex kernels.
    if "complex" in c_type:
        raise CompilationError(
            f"no function {expr.__class__.__name__} for {expr} in {c_type}"
        )
    from sympy.printing.codeprinter import ccode
    return set(), set(), [
        ccode(
            expr.xreplace({s: Symbol(f"{s}{indexing[s]}") for s in indexing}),
            assign_to=atom2str(out, indexing, c_type),  # noqa: W0401
            standard="c99",
            contract=False,
            human=True,
        )
    ]
[docs]
def c_piecewise(
    out: Symbol, indexing: collections.defaultdict[Symbol, str], c_type: str, *parts: Atom
) -> tuple[set[str], set[Symbol], list[str]]:
    """C ... if ... else ... operation.

    The ``(value, condition)`` parts are tested in order, and the first true condition
    selects the value. When no condition holds (no final ``True`` branch), ``out`` is set
    to NaN, as sympy does.

    Examples
    --------
    >>> from sympy.abc import x
    >>> from sympy.functions.elementary.piecewise import Piecewise
    >>> import numpy as np
    >>> from cutcutcodec.core.compilation.sympy_to_torch.lambdify import _lambdify_c
    >>> from cutcutcodec.core.compilation.sympy_to_torch.printer import _print_atomic, _printer
    >>> func = _lambdify_c(_printer([(x, Piecewise((0, x < 0), (x, True)))], {}, {x}))
    >>> func(np.array([np.nan, -np.inf, -1.0, 0.0, 1.0, np.inf]))
    array([nan, 0., 0., 0., 1., inf])
    >>>
    """
    from sympy.logic.boolalg import BooleanTrue
    (value, cond), *others = parts

    # final default branch: plain assignment, no test required
    if not others and isinstance(cond, BooleanTrue):
        return _print_atomic(value, out, indexing, c_type)

    # the context functions and the allocations of every sub-expression are propagated
    context: set[str] = set()
    alloc: set[Symbol] = set()

    # condition, evaluated in the temporary ``_buf`` when it is not atomic
    if cond.is_Atom:
        test, code = atom2str(cond, indexing, c_type), []  # noqa: W0401
    else:
        buf = Symbol("_buf")
        sub_context, sub_alloc, sub_code = _print_atomic(cond, buf, indexing, c_type)
        context |= sub_context
        alloc |= sub_alloc | {buf}
        test, code = atom2str(buf, indexing, c_type), list(sub_code)  # noqa: W0401

    # value when the condition holds
    code.append(f"if ({test}) {{")
    sub_context, sub_alloc, sub_code = _print_atomic(value, out, indexing, c_type)
    context |= sub_context
    alloc |= sub_alloc
    code.extend(sub_code)

    # otherwise, the remaining parts, or NaN if there is none
    code.append("} else {")
    if others:
        sub_context, sub_alloc, sub_code = c_piecewise(out, indexing, c_type, *others)
        context |= sub_context
        alloc |= sub_alloc
        code.extend(sub_code)
    else:
        code.append(f"{atom2str(out, indexing, c_type)} = NAN;")  # noqa: W0401
    code.append("}")
    return context, alloc, code