Skip to content

Commit a6b8101

Browse files
committed
EOD checkpoint
1 parent 91b3835 commit a6b8101

File tree

3 files changed

+236
-13
lines changed

3 files changed

+236
-13
lines changed

mdio/dataset.h

Lines changed: 169 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#define MDIO_API_VERSION "1.0.0"
1818

19+
#include <algorithm>
1920
#include <cstddef>
2021
#include <fstream>
2122
#include <limits>
@@ -536,6 +537,12 @@ class Dataset {
536537
Result<Dataset> isel(Descriptors&... descriptors) {
537538
VariableCollection vars;
538539

540+
std::cout << "isel forwarded descriptors..." << std::endl;
541+
((std::cout << "Descriptor: " << descriptors.label.label() << " "
542+
<< descriptors.start << " " << descriptors.stop << " "
543+
<< descriptors.step << std::endl), ...);
544+
std::cout << "================================================" << std::endl;
545+
539546
// the shape of the new domain
540547
std::map<std::string, tensorstore::IndexDomainDimension<>> dims;
541548
std::vector<std::string> keys = variables.get_iterable_accessor();
@@ -611,6 +618,103 @@ class Dataset {
611618
return isel(slices[I]...);
612619
}
613620

621+
/// Merge overlapping or adjacent descriptors with the same label & step.
622+
// std::vector<RangeDescriptor<Index>> merge_adjacent(std::vector<RangeDescriptor<Index>> descs) {
623+
// // 1) Bucket by (label, step)
624+
// using Key = std::pair<std::string,Index>;
625+
// std::map<Key, std::vector<RangeDescriptor<Index>>> buckets;
626+
// for (auto &d : descs) {
627+
// Key k = std::make_pair(std::string(d.label.label()), d.step);
628+
// buckets[k].push_back(d);
629+
// }
630+
631+
// std::vector<RangeDescriptor<Index>> result;
632+
// result.reserve(descs.size());
633+
634+
// // 2) For each bucket, sort & merge
635+
// for (auto &kv : buckets) {
636+
// const auto &[label, step] = kv.first;
637+
// auto &vec = kv.second;
638+
639+
// std::sort(vec.begin(), vec.end(),
640+
// [](auto const &a, auto const &b) {
641+
// return a.start < b.start;
642+
// });
643+
644+
// // 3) Sweep through and merge
645+
// Index cur_start = vec[0].start;
646+
// Index cur_stop = vec[0].stop;
647+
648+
// for (size_t i = 1; i < vec.size(); ++i) {
649+
// if (vec[i].start <= cur_stop) {
650+
// // overlap or adjacent
651+
// cur_stop = std::max(cur_stop, vec[i].stop);
652+
// } else {
653+
// // emit the completed run
654+
// std::string lab = label;
655+
// result.push_back({std::move(lab), cur_start, cur_stop, step});
656+
// // start a new one
657+
// cur_start = vec[i].start;
658+
// cur_stop = vec[i].stop;
659+
// }
660+
// }
661+
// std::string lab = label;
662+
// // emit final run for this bucket
663+
// result.push_back({std::move(lab), cur_start, cur_stop, step});
664+
// }
665+
666+
// return result;
667+
// }
668+
/**
 * @brief Coalesces descriptors that share a label and a step into fewer
 * half-open ranges that select the same indices.
 *
 * Descriptors are grouped by (label, step), each group is sorted by `start`,
 * and a sweep extends the current run while the next descriptor overlaps or
 * touches it. Runs are only extended when doing so cannot change which
 * indices are selected:
 *   - step == 1: any overlapping or adjacent range may be absorbed.
 *   - step  > 1: the next range must additionally start on the same stride
 *     phase as the run (e.g. [0,10,2] and [1,11,2] must NOT merge, because
 *     [0,11,2] would drop every odd index).
 *   - step  < 1: never merged; such descriptors are passed through as-is.
 *
 * @param descs The descriptors to coalesce (taken by value; the bucket keys
 * view into the labels carried by these copies).
 * @return The coalesced descriptors, ordered by (label, step) and then by
 * `start`. The input ordering is not preserved.
 */
std::vector<RangeDescriptor<Index>> merge_adjacent(
    std::vector<RangeDescriptor<Index>> descs) {
  // 1) Bucket by (label, step). A bucket is only ever created by the
  //    push_back below, so no bucket can be empty.
  using Key = std::pair<std::string_view, Index>;
  std::map<Key, std::vector<RangeDescriptor<Index>>> buckets;
  for (auto& d : descs) {
    buckets[{d.label.label(), d.step}].push_back(d);
  }

  std::vector<RangeDescriptor<Index>> result;
  result.reserve(descs.size());

  // 2) Sort and sweep each bucket.
  for (auto& [key, vec] : buckets) {
    const Index step = key.second;
    std::sort(vec.begin(), vec.end(),
              [](auto const& a, auto const& b) { return a.start < b.start; });

    // Every emitted run is a copy of the bucket's first descriptor with its
    // bounds overwritten, so the label and step are carried over unchanged.
    auto run = vec.front();
    for (std::size_t i = 1; i < vec.size(); ++i) {
      const auto& next = vec[i];
      const bool touches = next.start <= run.stop;
      // Starts are sorted, so the difference below is non-negative. The
      // `step > 1` guard also keeps the modulo away from a zero divisor.
      const bool aligned =
          step == 1 || (step > 1 && (next.start - run.start) % step == 0);
      if (touches && aligned) {
        run.stop = std::max(run.stop, next.stop);
      } else {
        // Emit the finished run and begin a new one at `next`.
        result.push_back(run);
        run.start = next.start;
        run.stop = next.stop;
      }
    }
    // Emit the final run of this bucket.
    result.push_back(std::move(run));
  }

  return result;
}
716+
717+
614718
// Wrapper function that generates the index sequence
615719
/**
616720
* @brief This version of isel is only expected to be used internally.
@@ -621,30 +725,74 @@ class Dataset {
621725
* number of descriptors.
622726
*/
623727
Result<Dataset> isel(const std::vector<RangeDescriptor<Index>>& slices) {
  // Slices the dataset by a runtime-sized list of descriptors:
  //   1. reject an empty request,
  //   2. coalesce overlapping/adjacent descriptors via merge_adjacent,
  //   3. if more than internal::kMaxNumSlices remain, split the list and
  //      apply the two halves one after the other,
  //   4. otherwise pad with inert descriptors up to kMaxNumSlices and forward
  //      to the variadic isel through call_isel_with_vector_impl.
  // Returns the sliced Dataset, or InvalidArgumentError when there is nothing
  // to slice by.
  if (slices.empty()) {
    return absl::InvalidArgumentError("No slices provided.");
  }

  const auto reducedSlices = merge_adjacent(slices);
  if (reducedSlices.empty()) {
    // Defensive: merge_adjacent emits at least one run per non-empty bucket,
    // so this is not expected for a non-empty input.
    return absl::InvalidArgumentError("No slices provided.");
  }

  if (reducedSlices.size() > internal::kMaxNumSlices) {
    std::size_t halfElements = reducedSlices.size() / 2;
    if (halfElements % 2 != 0) {
      halfElements += 1;
    }
    std::vector<RangeDescriptor<Index>> firstHalf(
        reducedSlices.begin(), reducedSlices.begin() + halfElements);
    std::vector<RangeDescriptor<Index>> secondHalf(
        reducedSlices.begin() + halfElements, reducedSlices.end());
    // The casts pin overload resolution to this vector overload rather than
    // the variadic isel template.
    MDIO_ASSIGN_OR_RETURN(
        auto ds,
        isel(static_cast<const std::vector<RangeDescriptor<Index>>&>(
            firstHalf)));
    return ds.isel(
        static_cast<const std::vector<RangeDescriptor<Index>>&>(secondHalf));
  }

  // NOTE(review): descriptors that repeat a label (disjoint ranges along one
  // dimension) are forwarded unchanged, as they were before this checkpoint;
  // Variable::slice appears to handle repeated labels by concatenating the
  // pieces. A dedicated per-label path is still TODO -- confirm the intended
  // semantics before adding one.
  std::vector<RangeDescriptor<Index>> slicesCopy = reducedSlices;
  // Only indices [0, kMaxNumSlices) are expanded by the index sequence, so
  // padding up to exactly kMaxNumSlices entries is sufficient.
  while (slicesCopy.size() < internal::kMaxNumSlices) {
    slicesCopy.emplace_back(
        RangeDescriptor<Index>({internal::kInertSliceKey, 0, 1, 1}));
  }

  // Generate the index sequence and call the implementation
  return call_isel_with_vector_impl(
      slicesCopy, std::make_index_sequence<internal::kMaxNumSlices>{});
}
649797

650798
/**
@@ -1032,9 +1180,11 @@ class Dataset {
10321180
std::vector<RangeDescriptor<Index>> elementwiseSlices;
10331181
for (Index idx = offset; idx < offset + nSamples; ++idx) {
10341182
if (data_ptr[idx] == coord_desc.value) {
1183+
// std::cout << "Found value at index: " << idx << std::endl;
10351184
indices.push_back(idx);
10361185
for (const auto& pos : currentPos) {
1037-
elementwiseSlices.emplace_back(RangeDescriptor<Index>({coord_desc.label.label(), pos.inclusive_min, pos.inclusive_min+1, 1}));
1186+
elementwiseSlices.emplace_back(RangeDescriptor<Index>({pos.label, pos.inclusive_min, pos.inclusive_min+1, 1}));
1187+
// std::cout << pos << std::endl;
10381188
}
10391189
}
10401190
this->_current_position_increment<T>(currentPos, interval);
@@ -1048,6 +1198,12 @@ class Dataset {
10481198

10491199
// TODO(BrianMichell): Coalesce the slices into fewer descriptors.
10501200

1201+
std::cout << "All RangeDescriptors: " << std::endl;
1202+
for (const auto& slice : elementwiseSlices) {
1203+
// std::cout << slice << std::endl;
1204+
std::cout << "[" << slice.label << ", " << slice.start << ", " << slice.stop << ", " << slice.step << "]" << std::endl;
1205+
}
1206+
10511207
MDIO_ASSIGN_OR_RETURN(auto ds, isel(static_cast<const std::vector<RangeDescriptor<Index>>&>(elementwiseSlices)));
10521208
// TODO(BrianMichell): Make this method more async friendly.
10531209
return tensorstore::ReadyFuture<Dataset>(std::move(ds));

mdio/dataset_test.cc

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -733,6 +733,62 @@ TEST(Dataset, where3) {
733733

734734
}
735735

736+
TEST(Dataset, where4) {
  // Create a clean copy of the toy dataset on disk.
  std::string path = "zarrs/selTester.mdio";
  auto schema = GetToyExample();
  auto createdRes =
      mdio::Dataset::from_json(schema, path, mdio::constants::kCreateClean);
  ASSERT_TRUE(createdRes.status().ok()) << createdRes.status();
  auto ds = createdRes.value();

  // Load "cdp-x" and overwrite every sample with a background value.
  auto coordRes = ds.variables.get<mdio::dtypes::float32_t>("cdp-x");
  ASSERT_TRUE(coordRes.status().ok()) << coordRes.status();
  auto coordVar = coordRes.value();
  auto coordReadFut = coordVar.Read();
  ASSERT_TRUE(coordReadFut.status().ok()) << coordReadFut.status();
  auto coordData = coordReadFut.value();
  auto accessor = coordData.get_data_accessor();
  for (auto i = 0; i < coordVar.num_samples(); i++) {
    const auto flatIdx = i + coordData.get_flattened_offset();
    accessor.data()[flatIdx] = 1000.0f;
  }

  // Mark the positions that the where() query below should pick up.
  accessor({15, 7}) = 10.0f;
  accessor({15, 8}) = 10.0f;
  accessor({15, 9}) = 10.0f;
  accessor({117, 5}) = 10.0f;
  // accessor({117, 10}) = 10.0f;

  auto writeFut = coordVar.Write(coordData);
  ASSERT_TRUE(writeFut.status().ok()) << writeFut.status();

  // Select by the marked coordinate value.
  auto whereFut = ds.where(
      mdio::ValueDescriptor<mdio::dtypes::float32_t>({"cdp-x", 10.0f}));
  ASSERT_TRUE(whereFut.status().ok()) << whereFut.status();
  auto slicedDs = whereFut.value();
  std::cout << slicedDs << std::endl;

  // Dump every variable of the sliced dataset for manual inspection.
  std::set<std::string> varNames = {"cdp-x", "inline", "crossline", "depth",
                                    "image"};
  for (const auto& name : varNames) {
    std::cout << "================" << name << "================" << std::endl;
    auto slicedVar = slicedDs.variables.at(name).value();
    std::cout << slicedVar << std::endl;
    std::cout << slicedVar.get_spec().value().dump(4) << std::endl;
    std::cout << slicedVar.num_samples() << std::endl;
  }

  // Print the surviving "cdp-x" samples one per line.
  auto slicedCoordRes = slicedDs.variables.get<mdio::dtypes::float32_t>("cdp-x");
  ASSERT_TRUE(slicedCoordRes.status().ok()) << slicedCoordRes.status();
  auto slicedCoord = slicedCoordRes.value();
  auto slicedReadFut = slicedCoord.Read();
  ASSERT_TRUE(slicedReadFut.status().ok()) << slicedReadFut.status();
  auto samples = slicedReadFut.value().get_data_accessor().data();
  for (auto i = 0; i < slicedCoord.num_samples(); i++) {
    const auto flatIdx = i + slicedReadFut.value().get_flattened_offset();
    std::cout << "[" << i << "]: " << samples[flatIdx] << std::endl;
  }
}
791+
736792
TEST(Dataset, selValue) {
737793
std::string path = "zarrs/selTester.mdio";
738794
auto dsRes = makePopulated(path);

mdio/variable.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1043,6 +1043,8 @@ class Variable {
10431043
*/
10441044
template <typename... Descriptors>
10451045
Result<Variable> slice(const Descriptors&... descriptors) const {
1046+
std::stringstream ss;
1047+
ss << "Slicing variable: " << variableName << " with descriptors...";
10461048
constexpr size_t numDescriptors = sizeof...(descriptors);
10471049

10481050
auto tuple_descs = std::make_tuple(descriptors...);
@@ -1082,6 +1084,12 @@ class Variable {
10821084
},
10831085
tuple_descs);
10841086

1087+
for (auto i=0; i<labels.size(); ++i) {
1088+
ss << labels[i].label() << " " << start[i] << " " << stop[i] << " " << step[i] << std::endl;
1089+
}
1090+
1091+
std::cout << ss.str() << std::endl;
1092+
10851093
if (preconditionStatus >= 0) {
10861094
mdio::RangeDescriptor<Index> err;
10871095
std::apply(
@@ -1119,6 +1127,7 @@ class Variable {
11191127
store |
11201128
tensorstore::Dims(labels).HalfOpenInterval(start, stop, step));
11211129
// return a new variable with the sliced store
1130+
std::cout << "Sliced variable: " << variableName << " with no cat" << std::endl;
11221131
return Variable{variableName, longName, metadata, slice_store,
11231132
attributes};
11241133
} else if (labelSet.size() != labelSize) {
@@ -1164,6 +1173,7 @@ class Variable {
11641173
tensorstore::TensorStore<T, R, M>(tensorstore::unchecked,
11651174
catStore);
11661175
// Return a new Variable with the concatenated store
1176+
std::cout <<"Sliced variable: " << variableName << "with cat" << std::endl;
11671177
return Variable{variableName, longName, metadata, typedCatStore,
11681178
attributes};
11691179
}
@@ -1174,6 +1184,7 @@ class Variable {
11741184
"Unexpected error occured while trying to slice the Variable.");
11751185
}
11761186
// the slice didn't change anything in the variables dimensions.
1187+
std::cout << "Sliced variable: " << variableName << " didnt' change" << std::endl;
11771188
return *this;
11781189
}
11791190

0 commit comments

Comments
 (0)