Skip to content

Commit 60b680f

Browse files
authored
Merge pull request #179 from borglab/templated-global-functions
2 parents 82ab5e3 + cb1fcdf commit 60b680f

39 files changed

+826
-223
lines changed

gtwrap/interface_parser/function.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def from_parse_result(parse_result: ParseResults):
8282
return ArgumentList([])
8383

8484
def __repr__(self) -> str:
85-
return ",".join([repr(x) for x in self.args_list])
85+
return ", ".join([repr(x) for x in self.args_list])
8686

8787
def __len__(self) -> int:
8888
return len(self.args_list)
@@ -182,8 +182,7 @@ def __init__(self,
182182
self.args.parent = self
183183

184184
def __repr__(self) -> str:
185-
return "GlobalFunction: {}{}({})".format(self.return_type, self.name,
186-
self.args)
185+
return f"GlobalFunction: {self.name}({self.args}) -> {self.return_type}"
187186

188187
def to_cpp(self) -> str:
189188
"""Generate the C++ code for wrapping."""

gtwrap/interface_parser/template.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ class TypenameAndInstantiations:
3131
"""
3232
Rule to parse the template parameters.
3333
34-
template<typename POSE> // POSE is the Instantiation.
34+
template<typename POSE = {Pose2, Pose3}> // Pos2 and Pose3 are the `Instantiation`s.
3535
"""
3636
rule = (
3737
IDENT("typename") #

gtwrap/interface_parser/type.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@ class Typename:
4040
"""
4141

4242
namespaces_name_rule = delimitedList(IDENT, "::")
43-
instantiation_name_rule = delimitedList(IDENT, "::")
4443
rule = (
4544
namespaces_name_rule("namespaces_and_name") #
4645
).setParseAction(lambda t: Typename(t))
@@ -164,7 +163,7 @@ class Type:
164163
"""
165164
rule = (
166165
Optional(CONST("is_const")) #
167-
+ (BasicType.rule("basic") | CustomType.rule("qualified")) # BR
166+
+ (BasicType.rule("basic") | CustomType.rule("custom")) # BR
168167
+ Optional(
169168
SHARED_POINTER("is_shared_ptr") | RAW_POINTER("is_ptr")
170169
| REF("is_ref")) #
@@ -192,9 +191,9 @@ def from_parse_result(t: ParseResults):
192191
is_ref=t.is_ref,
193192
is_basic=True,
194193
)
195-
elif t.qualified:
194+
elif t.custom:
196195
return Type(
197-
typename=t.qualified.typename,
196+
typename=t.custom.typename,
198197
is_const=t.is_const,
199198
is_shared_ptr=t.is_shared_ptr,
200199
is_ptr=t.is_ptr,
@@ -212,6 +211,13 @@ def __repr__(self) -> str:
212211
is_const="const " if self.is_const else "",
213212
is_ptr_or_ref=" " + is_ptr_or_ref if is_ptr_or_ref else "")
214213

214+
def get_typename(self):
215+
"""
216+
Get the typename of this type without any qualifiers.
217+
E.g. for `const gtsam::Pose3& pose` this will return `gtsam::Pose3`.
218+
"""
219+
return self.typename.to_cpp()
220+
215221
def to_cpp(self) -> str:
216222
"""
217223
Generate the C++ code for wrapping.
@@ -221,22 +227,18 @@ def to_cpp(self) -> str:
221227

222228
if self.is_shared_ptr:
223229
typename = "std::shared_ptr<{typename}>".format(
224-
typename=self.typename.to_cpp())
230+
typename=self.get_typename())
225231
elif self.is_ptr:
226232
typename = "{typename}*".format(typename=self.typename.to_cpp())
227233
elif self.is_ref:
228234
typename = typename = "{typename}&".format(
229-
typename=self.typename.to_cpp())
235+
typename=self.get_typename())
230236
else:
231-
typename = self.typename.to_cpp()
237+
typename = self.get_typename()
232238

233239
return ("{const}{typename}".format(
234240
const="const " if self.is_const else "", typename=typename))
235241

236-
def get_typename(self):
237-
"""Convenience method to get the typename of this type."""
238-
return self.typename.name
239-
240242

241243
class TemplatedType:
242244
"""
@@ -283,16 +285,21 @@ def __repr__(self):
283285
return "TemplatedType({typename.namespaces}::{typename.name})".format(
284286
typename=self.typename)
285287

286-
def to_cpp(self):
288+
def get_typename(self):
287289
"""
288-
Generate the C++ code for wrapping.
290+
Get the typename of this type without any qualifiers.
291+
E.g. for `const std::vector<double>& indices` this will return `std::vector<double>`.
289292
"""
290293
# Use Type.to_cpp to do the heavy lifting for the template parameters.
291294
template_args = ", ".join([t.to_cpp() for t in self.template_params])
292295

293-
typename = "{typename}<{template_args}>".format(
294-
typename=self.typename.qualified_name(),
295-
template_args=template_args)
296+
return f"{self.typename.qualified_name()}<{template_args}>"
297+
298+
def to_cpp(self):
299+
"""
300+
Generate the C++ code for wrapping.
301+
"""
302+
typename = self.get_typename()
296303

297304
if self.is_shared_ptr:
298305
typename = f"std::shared_ptr<{typename}>"

gtwrap/matlab_wrapper/mixins.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,15 @@
99
class CheckMixin:
1010
"""Mixin to provide various checks."""
1111
# Data types that are primitive types
12-
not_ptr_type: Tuple = ('int', 'double', 'bool', 'char', 'unsigned char',
13-
'size_t')
12+
not_ptr_type: Tuple = (
13+
"int",
14+
"double",
15+
"bool",
16+
"char",
17+
"unsigned char",
18+
"size_t",
19+
"Key", # This is an alias for a uint64_t
20+
)
1421
# Ignore the namespace for these datatypes
1522
ignore_namespace: Tuple = ('Matrix', 'Vector', 'Point2', 'Point3')
1623
# Methods that should be ignored
@@ -111,6 +118,9 @@ def _format_type_name(self,
111118
is_constructor: bool = False,
112119
is_method: bool = False):
113120
"""
121+
Helper method to get the string version of `type_name` which can go into the wrapper generated C++ code.
122+
This is specific to the semantics of Matlab.
123+
114124
Args:
115125
type_name: an interface_parser.Typename to reformat
116126
separator: the statement to add between namespaces and typename
@@ -133,6 +143,9 @@ def _format_type_name(self,
133143
if name not in self.ignore_namespace and namespace != '':
134144
formatted_type_name += namespace + separator
135145

146+
# Get string representation so we can use as dict key.
147+
name = str(name)
148+
136149
if is_constructor:
137150
formatted_type_name += self.data_type.get(name) or name
138151
elif is_method:

gtwrap/matlab_wrapper/wrapper.py

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@ def __init__(self,
5353
'Matrix': 'double',
5454
'int': 'numeric',
5555
'size_t': 'numeric',
56+
'Key': 'numeric',
5657
'bool': 'logical'
5758
}
5859
# Map the data type into the type used in Matlab methods.
@@ -68,6 +69,7 @@ def __init__(self,
6869
'Point3': 'double',
6970
'Vector': 'double',
7071
'Matrix': 'double',
72+
'Key': 'numeric',
7173
'bool': 'bool'
7274
}
7375
# The amount of times the wrapper has created a call to geometry_wrapper
@@ -108,7 +110,8 @@ def _update_wrapper_id(self,
108110
109111
Args:
110112
collector_function: tuple storing info about the wrapper function
111-
(namespace, class instance, function name, function object)
113+
(namespace, class/function instance,
114+
type of collector function, method object if class instance)
112115
id_diff: constant to add to the id in the map
113116
function_name: Optional custom function_name.
114117
@@ -372,9 +375,9 @@ def _unwrap_argument(self, arg, arg_id=0, instantiated_class=None):
372375
ctype_sep=ctype_sep, ctype=ctype_camel, id=arg_id)
373376

374377
else:
375-
arg_type = "{ctype}".format(ctype=arg.ctype.typename.name)
378+
arg_type = "{ctype}".format(ctype=self._format_type_name(arg.ctype.typename))
376379
unwrap = 'unwrap< {ctype} >(in[{id}]);'.format(
377-
ctype=arg.ctype.typename.name, id=arg_id)
380+
ctype=self._format_type_name(arg.ctype.typename), id=arg_id)
378381

379382
return arg_type, unwrap
380383

@@ -578,6 +581,7 @@ def wrap_global_function(self, function):
578581
# Get all combinations of parameters
579582
param_wrap = ''
580583

584+
# Iterate through possible overloads of the function
581585
for i, overload in enumerate(function):
582586
param_wrap += ' if' if i == 0 else ' elseif'
583587
param_wrap += ' length(varargin) == '
@@ -1218,7 +1222,7 @@ def wrap_namespace(self, namespace, add_mex_file=True):
12181222
if isinstance(func, parser.GlobalFunction)
12191223
]
12201224

1221-
self.wrap_methods(all_funcs, True, global_ns=namespace)
1225+
self.wrap_methods(all_funcs, global_funcs=True, global_ns=namespace)
12221226

12231227
return wrapped
12241228

@@ -1333,7 +1337,7 @@ def _collector_return(self,
13331337
prefix=' ')
13341338
else:
13351339
expanded += ' out[0] = wrap< {0} >({1});'.format(
1336-
ctype.typename.name, obj)
1340+
self._format_type_name(ctype.typename), obj)
13371341

13381342
return expanded
13391343

@@ -1365,8 +1369,8 @@ def wrap_collector_function_return(self, method, instantiated_class=None):
13651369
method_name += method.original.name
13661370

13671371
elif isinstance(method, parser.GlobalFunction):
1368-
method_name = self._format_global_function(method, '::')
1369-
method_name += method.name
1372+
namespace = self._format_global_function(method, '::')
1373+
method_name = namespace + method.to_cpp()
13701374

13711375
else:
13721376
if isinstance(method.parent, instantiator.InstantiatedClass):
@@ -1624,7 +1628,7 @@ def generate_collector_function(self, func_id):
16241628

16251629
body += self._wrapper_unwrap_arguments(collector_func[1].args)[1]
16261630
body += self.wrap_collector_function_return(
1627-
collector_func[1]) + '\n}\n'
1631+
collector_func[1]) + "\n}\n"
16281632

16291633
collector_function += body
16301634

gtwrap/template_instantiator/function.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ class InstantiatedGlobalFunction(parser.GlobalFunction):
1414
template<T = {double}>
1515
T add(const T& x, const T& y);
1616
"""
17+
1718
def __init__(self, original, instantiations=(), new_name=''):
1819
self.original = original
1920
self.instantiations = instantiations
@@ -54,16 +55,14 @@ def __init__(self, original, instantiations=(), new_name=''):
5455
def to_cpp(self):
5556
"""Generate the C++ code for wrapping."""
5657
if self.original.template:
57-
instantiated_names = [
58+
instantiated_params = [
5859
"::".join(inst.namespaces + [inst.instantiated_name()])
5960
for inst in self.instantiations
6061
]
61-
ret = "{}<{}>".format(self.original.name,
62-
",".join(instantiated_names))
62+
ret = f"{self.original.name}<{','.join(instantiated_params)}>"
6363
else:
6464
ret = self.original.name
6565
return ret
6666

6767
def __repr__(self):
68-
return "Instantiated {}".format(
69-
super(InstantiatedGlobalFunction, self).__repr__())
68+
return f"Instantiated {super().__repr__}"

gtwrap/template_instantiator/helpers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,6 @@ def instantiate_type(
6363
ctype.typename.instantiations[idx].name =\
6464
instantiations[template_idx]
6565

66-
6766
str_arg_typename = str(ctype.typename)
6867

6968
# Check if template is a scoped template e.g. T::Value where T is the template
@@ -88,6 +87,7 @@ def instantiate_type(
8887
is_ref=ctype.is_ref,
8988
is_basic=ctype.is_basic,
9089
)
90+
9191
# Check for exact template match.
9292
elif str_arg_typename in template_typenames:
9393
idx = template_typenames.index(str_arg_typename)
@@ -228,6 +228,7 @@ class InstantiationHelper:
228228
parent=parent)
229229
```
230230
"""
231+
231232
def __init__(self, instantiation_type: InstantiatedMembers):
232233
self.instantiation_type = instantiation_type
233234

matlab.h

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -175,6 +175,14 @@ mxArray* wrap<int>(const int& value) {
175175
return result;
176176
}
177177

178+
// specialization to gtsam::Key which is an alias for uint64_t
179+
template<>
180+
mxArray* wrap<uint64_t>(const uint64_t& value) {
181+
mxArray *result = scalar(mxUINT32OR64_CLASS);
182+
*(uint64_t*)mxGetData(result) = value;
183+
return result;
184+
}
185+
178186
// specialization to double -> just double
179187
template<>
180188
mxArray* wrap<double>(const double& value) {
@@ -330,6 +338,13 @@ int unwrap<int>(const mxArray* array) {
330338
return myGetScalar<int>(array);
331339
}
332340

341+
// specialization to gtsam::Key which is an alias for uint64_t
342+
template<>
343+
uint64_t unwrap<uint64_t>(const mxArray* array) {
344+
checkScalar(array,"unwrap<uint64_t>");
345+
return myGetScalar<uint64_t>(array);
346+
}
347+
333348
// specialization to size_t
334349
template<>
335350
size_t unwrap<size_t>(const mxArray* array) {
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
function varargout = EliminateDiscrete(varargin)
2+
if length(varargin) == 2 && isa(varargin{1},'gtsam.DiscreteFactorGraph') && isa(varargin{2},'gtsam.Ordering')
3+
[ varargout{1} varargout{2} ] = functions_wrapper(25, varargin{:});
4+
else
5+
error('Arguments do not match any overload of function EliminateDiscrete');
6+
end
7+
end
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
function varargout = FindKarcherMeanPoint3(varargin)
2+
if length(varargin) == 1 && isa(varargin{1},'std.vectorgtsam::Point3')
3+
varargout{1} = functions_wrapper(28, varargin{:});
4+
else
5+
error('Arguments do not match any overload of function FindKarcherMeanPoint3');
6+
end
7+
end

0 commit comments

Comments
 (0)