Skip to content

Commit c7d74be

Browse files
committed
Default T=ExprPtr for deserialization
1 parent 170de33 commit c7d74be

17 files changed

+174
-194
lines changed

SeQuant/core/io/serialization/serialization.hpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,10 @@ struct SerializationOptions {
5959
};
6060

6161
#define SEQUANT_DECLARE_DESERIALIZATION_FUNC \
62-
template <typename T> \
62+
template <typename T = ExprPtr> \
6363
T from_string(std::string_view input, \
6464
const DeserializationOptions &options = {}) = delete; \
65-
template <typename T> \
65+
template <typename T = ExprPtr> \
6666
T from_string(std::wstring_view input, \
6767
const DeserializationOptions &options = {}) = delete;
6868

SeQuant/core/io/shorthands.hpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
/// of simple wrapper functions directly in the sequant namespace which delegate
66
/// to the respective functions in the sequant::io namespace.
77

8+
#include <SeQuant/core/expr_fwd.hpp>
89
#include <SeQuant/core/io/concepts.hpp>
910
#include <SeQuant/core/io/latex/latex.hpp>
1011
#include <SeQuant/core/io/serialization/serialization.hpp>
@@ -29,15 +30,15 @@ decltype(auto) serialize(
2930
}
3031

3132
/// Shorthand for io::serialization::from_string
32-
template <typename T>
33+
template <typename T = ExprPtr>
3334
requires(io::deserializable<T>)
3435
decltype(auto) deserialize(
3536
std::string_view input,
3637
const io::serialization::DeserializationOptions &options = {}) {
3738
return io::serialization::from_string<T>(input, options);
3839
}
3940

40-
template <typename T>
41+
template <typename T = ExprPtr>
4142
requires(io::deserializable<T>)
4243
decltype(auto) deserialize(
4344
std::wstring_view input,

tests/integration/eval/btas/scf_btas.hpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -37,25 +37,24 @@ class SequantEvalScfBTAS final : public SequantEvalScf {
3737

3838
Tensor_t const& f_vo() const {
3939
static Tensor_t tnsr = data_world_(
40-
deserialize<ExprPtr>(L"f{a1;i1}", {.def_perm_symm = Symmetry::Nonsymm})
40+
deserialize(L"f{a1;i1}", {.def_perm_symm = Symmetry::Nonsymm})
4141
->as<Tensor>());
4242
return tnsr;
4343
}
4444

4545
Tensor_t const& g_vvoo() const {
46-
static Tensor_t tnsr =
47-
data_world_(deserialize<ExprPtr>(L"g{a1,a2;i1,i2}",
48-
{.def_perm_symm = Symmetry::Nonsymm})
49-
->as<Tensor>());
46+
static Tensor_t tnsr = data_world_(
47+
deserialize(L"g{a1,a2;i1,i2}", {.def_perm_symm = Symmetry::Nonsymm})
48+
->as<Tensor>());
5049
return tnsr;
5150
}
5251

5352
double energy_spin_orbital() {
5453
static const std::wstring_view energy_expr =
5554
L"f{i1;a1} * t{a1;i1} + g{i1,i2;a1,a2} * "
5655
L"(1/4 * t{a1,a2;i1,i2} + 1/2 t{a1;i1} * t{a2;i2})";
57-
static auto const node = binarize<EvalExprBTAS>(deserialize<ExprPtr>(
58-
energy_expr, {.def_perm_symm = Symmetry::Antisymm}));
56+
static auto const node = binarize<EvalExprBTAS>(
57+
deserialize(energy_expr, {.def_perm_symm = Symmetry::Antisymm}));
5958

6059
return evaluate(node, data_world_)->template get<double>();
6160
}

tests/integration/eval/ta/scf_ta.hpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -34,25 +34,24 @@ class SequantEvalScfTA final : public SequantEvalScf {
3434

3535
Tensor_t const& f_vo() const {
3636
static Tensor_t tnsr = data_world_(
37-
deserialize<ExprPtr>(L"f{a1;i1}", {.def_perm_symm = Symmetry::Nonsymm})
37+
deserialize(L"f{a1;i1}", {.def_perm_symm = Symmetry::Nonsymm})
3838
->as<Tensor>());
3939
return tnsr;
4040
}
4141

4242
Tensor_t const& g_vvoo() const {
43-
static Tensor_t tnsr =
44-
data_world_(deserialize<ExprPtr>(L"g{a1,a2;i1,i2}",
45-
{.def_perm_symm = Symmetry::Nonsymm})
46-
->as<Tensor>());
43+
static Tensor_t tnsr = data_world_(
44+
deserialize(L"g{a1,a2;i1,i2}", {.def_perm_symm = Symmetry::Nonsymm})
45+
->as<Tensor>());
4746
return tnsr;
4847
}
4948

5049
double energy_spin_orbital() {
5150
static const std::wstring_view energy_expr =
5251
L"f{i1;a1} * t{a1;i1} + g{i1,i2;a1,a2} * "
5352
L"(1/4 * t{a1,a2;i1,i2} + 1/2 t{a1;i1} * t{a2;i2})";
54-
static auto const node = binarize<EvalExprTA>(deserialize<ExprPtr>(
55-
energy_expr, {.def_perm_symm = Symmetry::Antisymm}));
53+
static auto const node = binarize<EvalExprTA>(
54+
deserialize(energy_expr, {.def_perm_symm = Symmetry::Antisymm}));
5655

5756
return evaluate(node, data_world_)->template get<double>();
5857
}

tests/unit/test_biorthogonalization.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ TEST_CASE("biorthogonalization", "[Biorthogonalization]") {
3939
for (std::size_t i = 0; i < inputs.size(); ++i) {
4040
CAPTURE(i);
4141

42-
ExprPtr input_expr = deserialize<ExprPtr>(inputs.at(i));
42+
ExprPtr input_expr = deserialize(inputs.at(i));
4343

4444
auto externals = external_indices(input_expr);
4545

tests/unit/test_canonicalize.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -138,12 +138,12 @@ TEST_CASE("canonicalization", "[algorithms]") {
138138
// the context of a sum external index labels are meaningful and should be
139139
// accounted.
140140
for (auto ignore_named_index_labels : {true, false}) {
141-
auto input1 = deserialize<ExprPtr>(
142-
L"1/2 t{a3,a1,a2;i4,i5,i2}:N-C-S g{i4,i5;i3,i1}:N-C-S");
143-
// auto input1 = deserialize<ExprPtr>(L"1/2
141+
auto input1 =
142+
deserialize(L"1/2 t{a3,a1,a2;i4,i5,i2}:N-C-S g{i4,i5;i3,i1}:N-C-S");
143+
// auto input1 = deserialize(L"1/2
144144
// t{a1,a2,a3;i5,i2,i4}:N-C-S g{i4,i5;i3,i1}:N-C-S");
145-
auto input2 = deserialize<ExprPtr>(
146-
L"1/2 t{a1,a3,a2;i5,i4,i2}:N-C-S g{i5,i4;i1,i3}:N-C-S");
145+
auto input2 =
146+
deserialize(L"1/2 t{a1,a3,a2;i5,i4,i2}:N-C-S g{i5,i4;i1,i3}:N-C-S");
147147
canonicalize(
148148
input1,
149149
{.method = CanonicalizationMethod::Topological,

tests/unit/test_eval_expr.cpp

Lines changed: 37 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -24,11 +24,11 @@ namespace sequant {
2424
Tensor parse_tensor(
2525
std::wstring_view tnsr,
2626
const io::serialization::DeserializationOptions& options = {}) {
27-
return deserialize<ExprPtr>(tnsr, options)->as<Tensor>();
27+
return deserialize(tnsr, options)->as<Tensor>();
2828
}
2929

3030
Constant parse_constant(std::wstring_view c) {
31-
return deserialize<ExprPtr>(c)->as<Constant>();
31+
return deserialize(c)->as<Constant>();
3232
}
3333

3434
EvalExpr result_expr(EvalExpr const& left, EvalExpr const& right, EvalOp op) {
@@ -52,7 +52,7 @@ TEST_CASE("eval_expr", "[EvalExpr]") {
5252

5353
REQUIRE_NOTHROW(EvalExpr{t1});
5454

55-
auto p1 = deserialize<ExprPtr>(L"g_{i3,a1}^{i1,i2} * t_{a2}^{a3}");
55+
auto p1 = deserialize(L"g_{i3,a1}^{i1,i2} * t_{a2}^{a3}");
5656

5757
const auto& c2 = EvalExpr{p1->at(0)->as<Tensor>()};
5858
const auto& c3 = EvalExpr{p1->at(1)->as<Tensor>()};
@@ -69,12 +69,12 @@ TEST_CASE("eval_expr", "[EvalExpr]") {
6969

7070
REQUIRE(!x1.op_type());
7171

72-
auto p1 = deserialize<ExprPtr>(L"g_{i3,a1}^{i1,i2} * t_{a2}^{a3}");
72+
auto p1 = deserialize(L"g_{i3,a1}^{i1,i2} * t_{a2}^{a3}");
7373

7474
const auto& c2 = EvalExpr{p1->at(0)->as<Tensor>()};
7575
const auto& c3 = EvalExpr{p1->at(1)->as<Tensor>()};
7676

77-
auto x2 = EvalExpr(deserialize<ExprPtr>(L"1/2")->as<Constant>());
77+
auto x2 = EvalExpr(deserialize(L"1/2")->as<Constant>());
7878
REQUIRE(!x2.op_type());
7979

8080
REQUIRE(!EvalExpr{Variable{L"λ"}}.op_type());
@@ -131,12 +131,12 @@ TEST_CASE("eval_expr", "[EvalExpr]") {
131131
}
132132

133133
SECTION("result expr") {
134-
ExprPtr expr = deserialize<ExprPtr>(L"2 var");
134+
ExprPtr expr = deserialize(L"2 var");
135135
ExprPtr root_expr = binarize(expr)->expr();
136136
REQUIRE(root_expr->is<Variable>());
137137
REQUIRE(*root_expr != *expr);
138138

139-
expr = deserialize<ExprPtr>(L"2 t{a1;i1}");
139+
expr = deserialize(L"2 t{a1;i1}");
140140
root_expr = binarize(expr)->expr();
141141
REQUIRE(root_expr->is<Tensor>());
142142
REQUIRE(*root_expr != *expr);
@@ -186,9 +186,9 @@ TEST_CASE("eval_expr", "[EvalExpr]") {
186186
SECTION("Sequant expression") {
187187
const auto& str_t1 = L"g_{a1,a2}^{a3,a4}";
188188
const auto& str_t2 = L"t_{a3,a4}^{i1,i2}";
189-
const auto& t1 = deserialize<ExprPtr>(str_t1);
189+
const auto& t1 = deserialize(str_t1);
190190

191-
const auto& t2 = deserialize<ExprPtr>(str_t2);
191+
const auto& t2 = deserialize(str_t2);
192192

193193
const auto& x1 = EvalExpr{t1->as<Tensor>()};
194194
const auto& x2 = EvalExpr{t2->as<Tensor>()};
@@ -220,8 +220,7 @@ TEST_CASE("eval_expr", "[EvalExpr]") {
220220
const auto& x45 = result_expr(EvalExpr{t4}, EvalExpr{t5}, EvalOp::Product);
221221
const auto& x54 = result_expr(EvalExpr{t5}, EvalExpr{t4}, EvalOp::Product);
222222

223-
REQUIRE(x45.to_latex() ==
224-
deserialize<ExprPtr>(L"I_{a1,a2}^{i1,i2}")->to_latex());
223+
REQUIRE(x45.to_latex() == deserialize(L"I_{a1,a2}^{i1,i2}")->to_latex());
225224
REQUIRE(x45.to_latex() == x54.to_latex());
226225
}
227226

@@ -246,8 +245,8 @@ TEST_CASE("eval_expr", "[EvalExpr]") {
246245

247246
REQUIRE_FALSE(x1.hash_value() == x3.hash_value());
248247
REQUIRE_FALSE(x12.hash_value() == x3.hash_value());
249-
auto tree1 = binarize(deserialize<ExprPtr>(L"A C"));
250-
auto tree2 = binarize(deserialize<ExprPtr>(L"A t{a1;i1}"));
248+
auto tree1 = binarize(deserialize(L"A C"));
249+
auto tree2 = binarize(deserialize(L"A t{a1;i1}"));
251250

252251
REQUIRE(tree1->hash_value() != tree2->hash_value());
253252
}
@@ -266,12 +265,12 @@ TEST_CASE("eval_expr", "[EvalExpr]") {
266265
REQUIRE(x12.expr()->as<Tensor>().symmetry() == Symmetry::Nonsymm);
267266

268267
// whole bra <-> ket contraction between two symmetric tensors
269-
const auto t3 = deserialize<ExprPtr>(L"g_{i3,i4}^{i1,i2}",
270-
{.def_perm_symm = Symmetry::Symm})
271-
->as<Tensor>();
272-
const auto t4 = deserialize<ExprPtr>(L"t_{a1,a2}^{i3,i4}",
273-
{.def_perm_symm = Symmetry::Symm})
274-
->as<Tensor>();
268+
const auto t3 =
269+
deserialize(L"g_{i3,i4}^{i1,i2}", {.def_perm_symm = Symmetry::Symm})
270+
->as<Tensor>();
271+
const auto t4 =
272+
deserialize(L"t_{a1,a2}^{i3,i4}", {.def_perm_symm = Symmetry::Symm})
273+
->as<Tensor>();
275274

276275
const auto x34 = result_expr(EvalExpr{t3}, EvalExpr{t4}, EvalOp::Product);
277276

@@ -280,12 +279,12 @@ TEST_CASE("eval_expr", "[EvalExpr]") {
280279
REQUIRE(x34.expr()->as<Tensor>().symmetry() == Symmetry::Nonsymm);
281280

282281
// outer product of the same tensor
283-
const auto t5 = deserialize<ExprPtr>(L"f_{i1}^{a1}",
284-
{.def_perm_symm = Symmetry::Nonsymm})
285-
->as<Tensor>();
286-
const auto t6 = deserialize<ExprPtr>(L"f_{i2}^{a2}",
287-
{.def_perm_symm = Symmetry::Nonsymm})
288-
->as<Tensor>();
282+
const auto t5 =
283+
deserialize(L"f_{i1}^{a1}", {.def_perm_symm = Symmetry::Nonsymm})
284+
->as<Tensor>();
285+
const auto t6 =
286+
deserialize(L"f_{i2}^{a2}", {.def_perm_symm = Symmetry::Nonsymm})
287+
->as<Tensor>();
289288

290289
const auto& x56 = result_expr(EvalExpr{t5}, EvalExpr{t6}, EvalOp::Product);
291290

@@ -303,12 +302,12 @@ TEST_CASE("eval_expr", "[EvalExpr]") {
303302
REQUIRE(x78.expr()->as<Tensor>().symmetry() == Symmetry::Nonsymm);
304303

305304
// whole bra <-> ket contraction between symmetric and antisymmetric tensors
306-
auto const t9 = deserialize<ExprPtr>(L"g_{a1,a2}^{a3,a4}",
307-
{.def_perm_symm = Symmetry::Antisymm})
308-
->as<Tensor>();
309-
auto const t10 = deserialize<ExprPtr>(L"t_{a3,a4}^{i1,i2}",
310-
{.def_perm_symm = Symmetry::Symm})
311-
->as<Tensor>();
305+
auto const t9 =
306+
deserialize(L"g_{a1,a2}^{a3,a4}", {.def_perm_symm = Symmetry::Antisymm})
307+
->as<Tensor>();
308+
auto const t10 =
309+
deserialize(L"t_{a3,a4}^{i1,i2}", {.def_perm_symm = Symmetry::Symm})
310+
->as<Tensor>();
312311
auto const x910 = result_expr(EvalExpr{t9}, EvalExpr{t10}, EvalOp::Product);
313312
// todo:
314313
// REQUIRE(x910.expr()->as<Tensor>().symmetry() == Symmetry::Symm);
@@ -318,7 +317,7 @@ TEST_CASE("eval_expr", "[EvalExpr]") {
318317
#if 0
319318
SECTION("Symmetry of sum") {
320319
auto tensor = [](Symmetry s) {
321-
return deserialize<ExprPtr>(L"I_{i1,i2}^{a1,a2}", s)->as<Tensor>();
320+
return deserialize(L"I_{i1,i2}^{a1,a2}", s)->as<Tensor>();
322321
};
323322

324323
auto symmetry = [](const EvalExpr& x) {
@@ -359,14 +358,12 @@ TEST_CASE("eval_expr", "[EvalExpr]") {
359358
#endif
360359

361360
SECTION("Debug") {
362-
auto t1 =
363-
EvalExpr{deserialize<ExprPtr>(L"O{a_1<i_1,i_2>;a_1<i_3,i_2>}",
364-
{.def_perm_symm = Symmetry::Nonsymm})
365-
->as<Tensor>()};
366-
auto t2 =
367-
EvalExpr{deserialize<ExprPtr>(L"O{a_2<i_1,i_2>;a_2<i_3,i_2>}",
368-
{.def_perm_symm = Symmetry::Nonsymm})
369-
->as<Tensor>()};
361+
auto t1 = EvalExpr{deserialize(L"O{a_1<i_1,i_2>;a_1<i_3,i_2>}",
362+
{.def_perm_symm = Symmetry::Nonsymm})
363+
->as<Tensor>()};
364+
auto t2 = EvalExpr{deserialize(L"O{a_2<i_1,i_2>;a_2<i_3,i_2>}",
365+
{.def_perm_symm = Symmetry::Nonsymm})
366+
->as<Tensor>()};
370367

371368
REQUIRE_NOTHROW(result_expr(t1, t2, EvalOp::Product));
372369
}

0 commit comments

Comments
 (0)