Skip to content

Commit dcdd41d

Browse files
committed
Begin prep for final API
1 parent 0126032 commit dcdd41d

File tree

2 files changed

+181
-51
lines changed

2 files changed

+181
-51
lines changed
Lines changed: 149 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -14,44 +14,44 @@
1414
#include "tensorstore/index_space/index_transform_builder.h"
1515
#include "tensorstore/box.h"
1616

17+
#define MDIO_INTERNAL_PROFILING 0 // TODO(BrianMichell): Remove simple profiling code once we approach a more mature API access.
18+
1719
namespace mdio {
1820

21+
#ifdef MDIO_INTERNAL_PROFILING
22+
void timer(std::chrono::high_resolution_clock::time_point start) {
23+
auto end = std::chrono::high_resolution_clock::now();
24+
auto duration = std::chrono::duration_cast<std::chrono::microseconds>(end - start);
25+
std::cout << "Time taken: " << duration.count() << " microseconds" << std::endl;
26+
}
27+
#endif
1928

2029
/// \brief Collects valid index selections per dimension for a Dataset without
2130
/// performing slicing immediately.
2231
///
23-
/// Only dimensions explicitly filtered via add_selection appear in the map;
32+
/// Only dimensions explicitly filtered via filterByCoordinate appear in the map;
2433
/// any dimension not present should be treated as having its full index range.
25-
class IndexSelection {
34+
class CoordinateSelector {
2635
public:
2736
/// Construct from an existing Dataset (captures its full domain).
28-
explicit IndexSelection(const Dataset& dataset)
37+
explicit CoordinateSelector(const Dataset& dataset)
2938
: dataset_(dataset), base_domain_(dataset.domain) {}
3039

3140

3241
template <typename T>
33-
mdio::Future<void> add_selection(const ValueDescriptor<T>& descriptor) {
42+
mdio::Future<void> filterByCoordinate(const ValueDescriptor<T>& descriptor) {
3443
if (kept_runs_.empty()) {
3544
return _init_runs(descriptor);
3645
} else {
3746
return _add_new_run(descriptor);
3847
}
3948
}
4049

41-
/// \brief Emit a RangeDescriptor per surviving tuple coordinate, without coalescing.
42-
std::vector<RangeDescriptor<Index>> range_descriptors() const {
43-
std::vector<mdio::RangeDescriptor<mdio::Index>> ret;
44-
ret.reserve(kept_runs_.size() * kept_runs_[0].size());
45-
for (auto const& run : kept_runs_) {
46-
for (auto const& interval : run) { // TODO: This is not an interval!
47-
ret.emplace_back(RangeDescriptor<mdio::Index>{interval.label, interval.start, interval.stop, 1});
48-
}
49-
}
50-
return ret;
51-
}
52-
5350
template <typename T>
54-
Future<void> sort_runs(const std::string& sort_key) {
51+
Future<void> sortSelectionByKey(const std::string& sort_key) {
52+
#ifdef MDIO_INTERNAL_PROFILING
53+
auto start = std::chrono::high_resolution_clock::now();
54+
#endif
5555
auto non_const_ds = dataset_;
5656
const size_t n = kept_runs_.size();
5757

@@ -66,6 +66,11 @@ class IndexSelection {
6666

6767
std::vector<T> keys;
6868
keys.reserve(n);
69+
#ifdef MDIO_INTERNAL_PROFILING
70+
std::cout << "Set up sorting of " << sort_key << " ... ";
71+
timer(start);
72+
start = std::chrono::high_resolution_clock::now();
73+
#endif
6974
for (auto &f : reads) {
7075
// if (!f.status().ok()) return f.status();
7176
// auto data = f.value();
@@ -77,6 +82,11 @@ class IndexSelection {
7782
// auto n = std::get<3>(resolution); // Not required
7883
keys.push_back(data_ptr[offset]);
7984
}
85+
#ifdef MDIO_INTERNAL_PROFILING
86+
std::cout << "Waiting for reads to complete for " << sort_key << " ... ";
87+
timer(start);
88+
start = std::chrono::high_resolution_clock::now();
89+
#endif
8090

8191
// 2) Build and stable-sort an index array [0…n-1] by key
8292
std::vector<size_t> idx(n);
@@ -85,6 +95,11 @@ class IndexSelection {
8595
idx.begin(), idx.end(),
8696
[&](size_t a, size_t b) { return keys[a] < keys[b]; }
8797
);
98+
#ifdef MDIO_INTERNAL_PROFILING
99+
std::cout << "Sorting time for " << sort_key << " ... ";
100+
timer(start);
101+
start = std::chrono::high_resolution_clock::now();
102+
#endif
88103

89104
// 3) One linear, move-only pass into a temp buffer
90105
using Desc = std::decay_t<decltype(kept_runs_)>::value_type;
@@ -96,11 +111,18 @@ class IndexSelection {
96111

97112
// 4) Steal the buffer back
98113
kept_runs_ = std::move(tmp);
114+
#ifdef MDIO_INTERNAL_PROFILING
115+
std::cout << "Stealing buffer back time for " << sort_key << " ... ";
116+
timer(start);
117+
#endif
99118
return absl::OkStatus();
100119
}
101120

102121
template <typename T>
103-
Future<std::vector<T>> run_values(const std::string& output_variable) {
122+
Future<std::vector<T>> readSelection(const std::string& output_variable) {
123+
#ifdef MDIO_INTERNAL_PROFILING
124+
auto start = std::chrono::high_resolution_clock::now();
125+
#endif
104126
auto non_const_ds = dataset_;
105127
std::vector<Future<VariableData<T>>> reads;
106128
reads.reserve(kept_runs_.size());
@@ -116,6 +138,11 @@ class IndexSelection {
116138
}
117139
}
118140

141+
#ifdef MDIO_INTERNAL_PROFILING
142+
std::cout << "Set up reading of " << output_variable << " ... ";
143+
timer(start);
144+
start = std::chrono::high_resolution_clock::now();
145+
#endif
119146

120147
for (auto& f : reads) {
121148
MDIO_ASSIGN_OR_RETURN(auto resolution, _resolve_future<T>(f));
@@ -127,6 +154,11 @@ class IndexSelection {
127154
std::memcpy(buffer.data(), data_ptr + offset, n * sizeof(T));
128155
ret.insert(ret.end(), buffer.begin(), buffer.end());
129156
}
157+
158+
#ifdef MDIO_INTERNAL_PROFILING
159+
std::cout << "Reading time for " << output_variable << " ... ";
160+
timer(start);
161+
#endif
130162
return ret;
131163
}
132164

@@ -135,9 +167,23 @@ class IndexSelection {
135167
tensorstore::IndexDomain<> base_domain_;
136168
std::vector<std::vector<mdio::RangeDescriptor<mdio::Index>>> kept_runs_;
137169

170+
/*
171+
TODO: The built RangeDescriptors aren't behaving as I hoped.
172+
They are building the longest runs possible properly, however
173+
as it becomes disjointed we start to lose some info.
174+
175+
e.g. We can have [0,1], [0, 25], [0, 120] but
176+
the last dimension is actually [0, 1000].
177+
178+
What we should get instead is [0, 1], [0, 24], [0, 1000] and [0, 1], [24, 25], [0, 120]
179+
*/
180+
138181
template <typename T>
139182
Future<void> _init_runs(const ValueDescriptor<T>& descriptor) {
140183
using Interval = typename Variable<T>::Interval;
184+
#ifdef MDIO_INTERNAL_PROFILING
185+
auto start = std::chrono::high_resolution_clock::now();
186+
#endif
141187
MDIO_ASSIGN_OR_RETURN(auto var, dataset_.variables.get<T>(std::string(descriptor.label.label())));
142188
auto fut = var.Read();
143189
MDIO_ASSIGN_OR_RETURN(auto intervals, var.get_intervals());
@@ -153,6 +199,12 @@ class IndexSelection {
153199

154200
std::size_t run_idx = offset;
155201

202+
#ifdef MDIO_INTERNAL_PROFILING
203+
std::cout << "Initialize and read time... ";
204+
timer(start);
205+
start = std::chrono::high_resolution_clock::now();
206+
#endif
207+
156208
for (mdio::Index idx = offset; idx < offset + n_samples; ++idx) {
157209
bool is_match = data_ptr[idx] == descriptor.value;
158210

@@ -189,19 +241,97 @@ class IndexSelection {
189241
}
190242
}
191243

244+
#ifdef MDIO_INTERNAL_PROFILING
245+
std::cout << "Build runs time... ";
246+
timer(start);
247+
start = std::chrono::high_resolution_clock::now();
248+
#endif
249+
192250
if (local_runs.empty()) {
193251
std::stringstream ss;
194252
ss << "No matches for coordinate '" << descriptor.label.label() << "'";
195253
return absl::NotFoundError(ss.str());
196254
}
197255

198256
kept_runs_ = _from_intervals<T>(local_runs);
257+
#ifdef MDIO_INTERNAL_PROFILING
258+
std::cout << "Finalize time... ";
259+
timer(start);
260+
#endif
199261
return absl::OkStatus();
200262
}
201263

264+
/**
265+
* @brief Using the existing runs, further filter the Dataset by the new coordiante.
266+
*/
202267
template <typename T>
203268
Future<void> _add_new_run(const ValueDescriptor<T>& descriptor) {
204-
return absl::UnimplementedError("Adding selection to an existing IndexSelection is not yet implemented");
269+
using Interval = typename Variable<T>::Interval;
270+
std::vector<std::vector<Interval>> new_runs;
271+
272+
std::vector<std::vector<Interval>> stored_intervals; // Use this to ensure everything remains in memory until the Intervals are no longer needed.
273+
stored_intervals.reserve(kept_runs_.size());
274+
275+
auto non_const_ds = dataset_;
276+
277+
for (const auto& desc : kept_runs_) {
278+
MDIO_ASSIGN_OR_RETURN(auto ds, non_const_ds.isel(desc));
279+
MDIO_ASSIGN_OR_RETURN(auto var, ds.variables.get<T>(std::string(descriptor.label.label())));
280+
auto fut = var.Read();
281+
MDIO_ASSIGN_OR_RETURN(auto intervals, var.get_intervals());
282+
stored_intervals.push_back(std::move(intervals)); // Just to ensure nothing gets freed prematurely.
283+
MDIO_ASSIGN_OR_RETURN(auto resolution, _resolve_future<T>(fut));
284+
auto data = std::get<0>(resolution);
285+
auto data_ptr = std::get<1>(resolution);
286+
auto offset = std::get<2>(resolution);
287+
auto n = std::get<3>(resolution);
288+
289+
auto current_pos = intervals;
290+
bool isInRun = false;
291+
292+
std::size_t run_idx = offset;
293+
294+
for (Index idx = offset; idx < offset + n; ++idx) {
295+
bool is_match = data_ptr[idx] == descriptor.value;
296+
if (is_match && !isInRun) {
297+
isInRun = true;
298+
for (auto i=run_idx; i<idx; ++i) {
299+
_current_position_increment<T>(current_pos, intervals);
300+
}
301+
run_idx = idx;
302+
std::vector<Interval> run = current_pos;
303+
new_runs.push_back(std::move(run));
304+
} else if (is_match && isInRun) {
305+
// Somewhere in the middle of a run
306+
// do nothing TODO: Remove me
307+
} else if (!is_match && isInRun) {
308+
// The end of a run
309+
isInRun = false;
310+
for (auto i=run_idx; i<idx; ++i) {
311+
_current_position_increment<T>(current_pos, intervals);
312+
}
313+
run_idx = idx;
314+
auto& last_run = new_runs.back();
315+
for (auto i=0; i<current_pos.size(); ++i) {
316+
last_run[i].exclusive_max = current_pos[i].inclusive_min + 1;
317+
}
318+
} else if (!is_match && !isInRun) {
319+
// No run at all
320+
// do nothing TODO: Remove me
321+
} else {
322+
// base case TODO: Remove me
323+
}
324+
}
325+
}
326+
327+
if (new_runs.empty()) {
328+
std::stringstream ss;
329+
ss << "No matches for coordinate '" << descriptor.label.label() << "'";
330+
return absl::NotFoundError(ss.str());
331+
}
332+
333+
kept_runs_ = _from_intervals<T>(new_runs); // TODO: We need to ensure we don't accidentally drop any pre-sliced dimensions...
334+
return absl::OkStatus();
205335
}
206336

207337
template <typename T>

0 commit comments

Comments
 (0)