Skip to content

Commit c37b752

Browse files
Copilotgreenc-FNAL
andcommitted
Add Variant helper and address review comments
- Added variant.py helper from PR Framework-R-D#245 - Modified modulewrap.cpp to recognize Variant wrapper via phlex_callable - Updated adder.py to use Variant helper for type-specific registration - Removed debug print statements from verify_extended.py - Removed commented-out mutex code from modulewrap.cpp - Removed debug message() calls from CMakeLists.txt - Fixed LaTeX syntax in copilot-instructions.md (use Unicode ↔) Co-authored-by: greenc-FNAL <2372949+greenc-FNAL@users.noreply.github.com>
1 parent 0354f82 commit c37b752

6 files changed

Lines changed: 109 additions & 31 deletions

File tree

.github/copilot-instructions.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,7 @@ All Markdown files must strictly follow these markdownlint rules:
161161
- **C++ Driver**: Provides data streams (e.g., `test/python/driver.cpp`).
162162
- **Jsonnet Config**: Wires the graph (e.g., `test/python/pytypes.jsonnet`).
163163
- **Python Script**: Implements algorithms (e.g., `test/python/test_types.py`).
164-
- **Type Conversion**: `plugins/python/src/modulewrap.cpp` handles C++ $\leftrightarrow$ Python conversion.
164+
- **Type Conversion**: `plugins/python/src/modulewrap.cpp` handles C++ Python conversion.
165165
- **Mechanism**: Uses string comparison of type names (e.g., `"float64]]"`). This is brittle.
166166
- **Requirement**: Ensure converters exist for all types used in tests (e.g., `float`, `double`, `unsigned int`, and their vector equivalents).
167167
- **Warning**: Exact type matches are required. `numpy.float32` != `float`.

plugins/python/src/modulewrap.cpp

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99
#include <stdexcept>
1010
#include <vector>
1111

12-
// static std::mutex g_py_mutex;
13-
1412
#define NO_IMPORT_ARRAY
1513
#define PY_ARRAY_UNIQUE_SYMBOL phlex_ARRAY_API
1614
#include <numpy/arrayobject.h>
@@ -109,7 +107,6 @@ namespace {
109107
static_assert(sizeof...(Args) == N, "Argument count mismatch");
110108

111109
PyGILRAII gil;
112-
// std::lock_guard<std::mutex> lock(g_py_mutex);
113110

114111
PyObject* result = PyObject_CallFunctionObjArgs(
115112
(PyObject*)m_callable, lifeline_transform(args.get())..., nullptr);
@@ -132,7 +129,6 @@ namespace {
132129
static_assert(sizeof...(Args) == N, "Argument count mismatch");
133130

134131
PyGILRAII gil;
135-
// std::lock_guard<std::mutex> lock(g_py_mutex);
136132

137133
PyObject* result =
138134
PyObject_CallFunctionObjArgs((PyObject*)m_callable, (PyObject*)args.get()..., nullptr);
@@ -369,7 +365,6 @@ namespace {
369365
static PyObjectPtr vint_to_py(std::shared_ptr<std::vector<int>> const& v)
370366
{
371367
PyGILRAII gil;
372-
// std::lock_guard<std::mutex> lock(g_py_mutex);
373368
if (!v)
374369
return PyObjectPtr();
375370
PyObject* list = PyList_New(v->size());
@@ -392,7 +387,6 @@ namespace {
392387
static PyObjectPtr vuint_to_py(std::shared_ptr<std::vector<unsigned int>> const& v)
393388
{
394389
PyGILRAII gil;
395-
// std::lock_guard<std::mutex> lock(g_py_mutex);
396390
if (!v)
397391
return PyObjectPtr();
398392
PyObject* list = PyList_New(v->size());
@@ -415,7 +409,6 @@ namespace {
415409
static PyObjectPtr vlong_to_py(std::shared_ptr<std::vector<long>> const& v)
416410
{
417411
PyGILRAII gil;
418-
// std::lock_guard<std::mutex> lock(g_py_mutex);
419412
if (!v)
420413
return PyObjectPtr();
421414
PyObject* list = PyList_New(v->size());
@@ -438,7 +431,6 @@ namespace {
438431
static PyObjectPtr vulong_to_py(std::shared_ptr<std::vector<unsigned long>> const& v)
439432
{
440433
PyGILRAII gil;
441-
// std::lock_guard<std::mutex> lock(g_py_mutex);
442434
if (!v)
443435
return PyObjectPtr();
444436
PyObject* list = PyList_New(v->size());
@@ -501,7 +493,6 @@ namespace {
501493
static std::shared_ptr<std::vector<int>> py_to_vint(PyObjectPtr pyobj)
502494
{
503495
PyGILRAII gil;
504-
// std::lock_guard<std::mutex> lock(g_py_mutex);
505496
auto vec = std::make_shared<std::vector<int>>();
506497
PyObject* obj = pyobj.get();
507498

@@ -541,7 +532,6 @@ namespace {
541532
static std::shared_ptr<std::vector<unsigned int>> py_to_vuint(PyObjectPtr pyobj)
542533
{
543534
PyGILRAII gil;
544-
// std::lock_guard<std::mutex> lock(g_py_mutex);
545535
auto vec = std::make_shared<std::vector<unsigned int>>();
546536
PyObject* obj = pyobj.get();
547537

@@ -581,7 +571,6 @@ namespace {
581571
static std::shared_ptr<std::vector<long>> py_to_vlong(PyObjectPtr pyobj)
582572
{
583573
PyGILRAII gil;
584-
// std::lock_guard<std::mutex> lock(g_py_mutex);
585574
auto vec = std::make_shared<std::vector<long>>();
586575
PyObject* obj = pyobj.get();
587576

@@ -621,7 +610,6 @@ namespace {
621610
static std::shared_ptr<std::vector<unsigned long>> py_to_vulong(PyObjectPtr pyobj)
622611
{
623612
PyGILRAII gil;
624-
// std::lock_guard<std::mutex> lock(g_py_mutex);
625613
auto vec = std::make_shared<std::vector<unsigned long>>();
626614
PyObject* obj = pyobj.get();
627615

@@ -661,7 +649,6 @@ namespace {
661649
static std::shared_ptr<std::vector<float>> py_to_vfloat(PyObjectPtr pyobj)
662650
{
663651
PyGILRAII gil;
664-
// std::lock_guard<std::mutex> lock(g_py_mutex);
665652
auto vec = std::make_shared<std::vector<float>>();
666653
PyObject* obj = pyobj.get();
667654

@@ -701,7 +688,6 @@ namespace {
701688
static std::shared_ptr<std::vector<double>> py_to_vdouble(PyObjectPtr pyobj)
702689
{
703690
PyGILRAII gil;
704-
// std::lock_guard<std::mutex> lock(g_py_mutex);
705691
auto vec = std::make_shared<std::vector<double>>();
706692
PyObject* obj = pyobj.get();
707693

@@ -866,8 +852,16 @@ static PyObject* parse_args(PyObject* args,
866852
return nullptr;
867853
}
868854

855+
// special case of Phlex Variant wrapper
856+
PyObject* wrapped_callable = PyObject_GetAttrString(callable, "phlex_callable");
857+
if (wrapped_callable) {
858+
callable = wrapped_callable;
859+
} else {
860+
PyErr_Clear();
861+
Py_INCREF(callable);
862+
}
863+
869864
// no common errors detected; actual registration may have more checks
870-
Py_INCREF(callable);
871865
return callable;
872866
}
873867

test/python/CMakeLists.txt

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,6 @@ list(APPEND ACTIVE_PY_CPHLEX_TESTS py:reduce)
154154
)
155155
list(APPEND ACTIVE_PY_CPHLEX_TESTS py:failure)
156156

157-
message(STATUS "Python_SITELIB: ${Python_SITELIB}")
158-
message(STATUS "Python_SITEARCH: ${Python_SITEARCH}")
159157
set(TEST_PYTHONPATH ${CMAKE_CURRENT_SOURCE_DIR})
160158
# Always add site-packages to PYTHONPATH for tests, as embedded python might
161159
# not find them especially in spack environments where they are in
@@ -171,7 +169,6 @@ list(APPEND ACTIVE_PY_CPHLEX_TESTS py:reduce)
171169
# Keep this for backward compatibility or if it adds something else
172170
endif()
173171
set(TEST_PYTHONPATH ${TEST_PYTHONPATH}:$ENV{PYTHONPATH})
174-
message(STATUS "TEST_PYTHONPATH: ${TEST_PYTHONPATH}")
175172

176173
# "failing" tests for checking error paths
177174
add_test(

test/python/adder.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,33 @@
44
real. It serves as a "Hello, World" equivalent for running Python code.
55
"""
66

7+
from typing import Protocol, TypeVar
78

8-
def add(i: int, j: int) -> int:
9+
from variant import Variant
10+
11+
12+
class AddableProtocol[T](Protocol):
13+
"""Typer bound for any types that can be added."""
14+
15+
def __add__(self, other: T) -> T: # noqa: D105
16+
...
17+
18+
19+
Addable = TypeVar('Addable', bound=AddableProtocol)
20+
21+
22+
def add(i: Addable, j: Addable) -> Addable:
923
"""Add the inputs together and return the sum total.
1024
1125
Use the standard `+` operator to add the two inputs together
1226
to arrive at their total.
1327
1428
Args:
15-
i (int): First input.
16-
j (int): Second input.
29+
i (Number): First input.
30+
j (Number): Second input.
1731
1832
Returns:
19-
int: Sum of the two inputs.
33+
Number: Sum of the two inputs.
2034
2135
Examples:
2236
>>> add(1, 2)
@@ -40,4 +54,5 @@ def PHLEX_REGISTER_ALGORITHMS(m, config):
4054
Returns:
4155
None
4256
"""
43-
m.transform(add, input_family=config["input"], output_products=config["output"])
57+
int_adder = Variant(add, {"i": int, "j": int, "return": int}, "iadd")
58+
m.transform(int_adder, input_family=config["input"], output_products=config["output"])

test/python/variant.py

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
"""Annotation helper for C++ typing variants.
2+
3+
Python algorithms are generic, like C++ templates, but the Phlex registration
4+
process requires a single unique signature. These helpers generate annotated
5+
functions for registration with the proper C++ types.
6+
"""
7+
8+
import copy
9+
from typing import Any, Callable
10+
11+
12+
class Variant:
13+
"""Wrapper to associate custom annotations with a callable.
14+
15+
This class wraps a callable and provides custom ``__annotations__`` and
16+
``__name__`` attributes, allowing the same underlying function or callable
17+
object to be registered multiple times with different type annotations.
18+
19+
By default, the provided callable is kept by reference, but can be cloned
20+
(e.g. for callable instances) if requested.
21+
22+
Phlex will recognize the "phlex_callable" data member, allowing an unwrap
23+
and thus saving an indirection. To detect performance degradation, the
24+
wrapper is not callable by default.
25+
26+
Attributes:
27+
phlex_callable (Callable): The underlying callable (public).
28+
__annotations__ (dict): Type information of arguments and return product.
29+
__name__ (str): The name associated with this variant.
30+
31+
Examples:
32+
>>> def add(i: Number, j: Number) -> Number:
33+
... return i + j
34+
...
35+
>>> int_adder = variant(add, {"i": int, "j": int, "return": int}, "iadd")
36+
"""
37+
38+
def __init__(
39+
self,
40+
f: Callable,
41+
annotations: dict[str, str | type | Any],
42+
name: str,
43+
clone: bool | str = False,
44+
allow_call: bool = False,
45+
):
46+
"""Annotate the callable F.
47+
48+
Args:
49+
f (Callable): Annotable function.
50+
annotations (dict): Type information of arguments and return product.
51+
name (str): Name to assign to this variant.
52+
clone (bool|str): If True (or "deep"), creates a shallow (deep) copy
53+
of the callable.
54+
allow_call (bool): Allow this wrapper to forward to the callable.
55+
"""
56+
if clone == 'deep':
57+
self.phlex_callable = copy.deepcopy(f)
58+
elif clone:
59+
self.phlex_callable = copy.copy(f)
60+
else:
61+
self.phlex_callable = f
62+
self.__annotations__ = annotations
63+
self.__name__ = name
64+
self._allow_call = allow_call
65+
66+
def __call__(self, *args, **kwargs):
67+
"""Raises an error if called directly.
68+
69+
Variant instances should not be called directly. The framework should
70+
extract ``phlex_callable`` instead and call that.
71+
72+
Raises:
73+
AssertionError: To indicate incorrect usage, unless overridden.
74+
"""
75+
assert self._allow_call, (
76+
f"TypedVariant '{self.__name__}' was called directly. "
77+
f"The framework should extract phlex_callable instead."
78+
)
79+
return self.phlex_callable(*args, **kwargs) # type: ignore

test/python/verify_extended.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,5 @@
11
"""Observers to check for various types in tests."""
22

3-
import sys
4-
53

64
class VerifierInt:
75
"""Verify int values."""
@@ -42,7 +40,6 @@ def __init__(self, sum_total: int):
4240

4341
def __call__(self, value: "long") -> None: # type: ignore # noqa: F821
4442
"""Check if value matches expected sum."""
45-
print(f"VerifierLong: value={value}, expected={self._sum_total}")
4643
assert value == self._sum_total
4744

4845

@@ -57,7 +54,6 @@ def __init__(self, sum_total: int):
5754

5855
def __call__(self, value: "unsigned long") -> None: # type: ignore # noqa: F722
5956
"""Check if value matches expected sum."""
60-
print(f"VerifierULong: value={value}, expected={self._sum_total}")
6157
assert value == self._sum_total
6258

6359

@@ -72,7 +68,6 @@ def __init__(self, sum_total: float):
7268

7369
def __call__(self, value: "float") -> None:
7470
"""Check if value matches expected sum."""
75-
sys.stderr.write(f"VerifierFloat: value={value}, expected={self._sum_total}\n")
7671
assert abs(value - self._sum_total) < 1e-5
7772

7873

@@ -87,7 +82,6 @@ def __init__(self, sum_total: float):
8782

8883
def __call__(self, value: "double") -> None: # type: ignore # noqa: F821
8984
"""Check if value matches expected sum."""
90-
print(f"VerifierDouble: value={value}, expected={self._sum_total}")
9185
assert abs(value - self._sum_total) < 1e-5
9286

9387

@@ -102,7 +96,6 @@ def __init__(self, expected: bool):
10296

10397
def __call__(self, value: bool) -> None:
10498
"""Check if value matches expected."""
105-
print(f"VerifierBool: value={value}, expected={self._expected}")
10699
assert value == self._expected
107100

108101

0 commit comments

Comments
 (0)