Skip to content

Commit 6d361f3

Browse files
committed
Enable strided index slicing
1 parent 838538d commit 6d361f3

File tree

3 files changed

+243
-29
lines changed

3 files changed

+243
-29
lines changed

mdio/dataset_test.cc

Lines changed: 233 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,239 @@ TEST(Dataset, isel) {
411411
<< "Inline range should end at 5";
412412
}
413413

414+
TEST(Dataset, iselWithStride) {
415+
// Tests the integrity of data that is written with a strided slice.
416+
std::string iselPath = "zarrs/acceptance";
417+
{ // Scoping the dataset creation to ensure the variables are cleaned up
418+
// before the testing.
419+
auto json_vars = GetToyExample();
420+
auto dataset = mdio::Dataset::from_json(json_vars, iselPath,
421+
mdio::constants::kCreateClean);
422+
ASSERT_TRUE(dataset.status().ok()) << dataset.status();
423+
auto ds = dataset.value();
424+
425+
mdio::RangeDescriptor<mdio::Index> desc1 = {"inline", 0, 256, 2};
426+
auto sliceRes = ds.isel(desc1);
427+
ASSERT_TRUE(sliceRes.status().ok()) << sliceRes.status();
428+
ds = sliceRes.value();
429+
430+
auto ilVarRes = ds.variables.get<mdio::dtypes::uint32_t>("inline");
431+
ASSERT_TRUE(ilVarRes.status().ok()) << ilVarRes.status();
432+
auto ilVar = ilVarRes.value();
433+
434+
auto ilDataRes = mdio::from_variable<mdio::dtypes::uint32_t>(ilVar);
435+
ASSERT_TRUE(ilDataRes.status().ok()) << ilDataRes.status();
436+
auto ilData = ilDataRes.value();
437+
438+
auto ilAccessor = ilData.get_data_accessor().data();
439+
for (uint32_t i = 0; i < 128; i++) {
440+
ilAccessor[i] = i * 2;
441+
}
442+
443+
auto ilFut = ilVar.Write(ilData);
444+
445+
ASSERT_TRUE(ilFut.status().ok()) << ilFut.status();
446+
447+
// --- Begin new QC data generation for the "image" variable (float32) ---
448+
auto imageVarRes = ds.variables.get<mdio::dtypes::float32_t>("image");
449+
ASSERT_TRUE(imageVarRes.status().ok()) << imageVarRes.status();
450+
auto imageVar = imageVarRes.value();
451+
452+
auto imageDataRes = mdio::from_variable<mdio::dtypes::float32_t>(imageVar);
453+
ASSERT_TRUE(imageDataRes.status().ok()) << imageDataRes.status();
454+
auto imageData = imageDataRes.value();
455+
456+
auto imageAccessor = imageData.get_data_accessor().data();
457+
for (uint32_t i = 0; i < 128; i++) {
458+
for (uint32_t j = 0; j < 512; j++) {
459+
for (uint32_t k = 0; k < 384; k++) {
460+
imageAccessor[i * (512 * 384) + j * 384 + k] =
461+
static_cast<float>(i * 2) + j * 0.1f + k * 0.01f;
462+
}
463+
}
464+
}
465+
466+
auto imageWriteFut = imageVar.Write(imageData);
467+
ASSERT_TRUE(imageWriteFut.status().ok()) << imageWriteFut.status();
468+
} // end of scoping the dataset creation to ensure the variables are cleaned
469+
// up before the testing.
470+
471+
auto reopenedDsFut = mdio::Dataset::Open(iselPath, mdio::constants::kOpen);
472+
ASSERT_TRUE(reopenedDsFut.status().ok()) << reopenedDsFut.status();
473+
auto reopenedDs = reopenedDsFut.value();
474+
475+
auto inlineVarRes =
476+
reopenedDs.variables.get<mdio::dtypes::uint32_t>("inline");
477+
ASSERT_TRUE(inlineVarRes.status().ok()) << inlineVarRes.status();
478+
auto inlineVar = inlineVarRes.value();
479+
480+
auto inlineDataFut = inlineVar.Read();
481+
ASSERT_TRUE(inlineDataFut.status().ok()) << inlineDataFut.status();
482+
auto inlineData = inlineDataFut.value();
483+
auto inlineAccessor = inlineData.get_data_accessor().data();
484+
for (uint32_t i = 0; i < 256; i++) {
485+
if (i % 2 == 0) {
486+
ASSERT_EQ(inlineAccessor[i], i) << "Expected inline value to be " << i
487+
<< " but got " << inlineAccessor[i];
488+
} else {
489+
ASSERT_EQ(inlineAccessor[i], 0)
490+
<< "Expected inline value to be 0 but got " << inlineAccessor[i];
491+
}
492+
}
493+
494+
auto imageVarResReopen =
495+
reopenedDs.variables.get<mdio::dtypes::float32_t>("image");
496+
ASSERT_TRUE(imageVarResReopen.status().ok()) << imageVarResReopen.status();
497+
auto imageVarReopen = imageVarResReopen.value();
498+
499+
auto imageDataFut = imageVarReopen.Read();
500+
ASSERT_TRUE(imageDataFut.status().ok()) << imageDataFut.status();
501+
auto imageDataFull = imageDataFut.value();
502+
auto imageAccessorFull = imageDataFull.get_data_accessor().data();
503+
504+
// Instead of checking all 256x512x384 elements (which can be very time
505+
// consuming), we check a few sample indices. For full "image" variable, for
506+
// every full inline index i: if (i % 2 == 0): the expected value is i +
507+
// j*0.1f + k*0.01f, otherwise NaN.
508+
std::vector<uint32_t> sample_i = {0, 1, 2,
509+
255}; // mix of even and odd indices
510+
std::vector<uint32_t> sample_j = {0, 256, 511};
511+
std::vector<uint32_t> sample_k = {0, 100, 383};
512+
513+
for (auto i : sample_i) {
514+
for (auto j : sample_j) {
515+
for (auto k : sample_k) {
516+
size_t index = i * (512 * 384) + j * 384 + k;
517+
float actual = imageAccessorFull[index];
518+
519+
if (i % 2 == 0) {
520+
// For even indices, we expect a specific value
521+
float expected = static_cast<float>(i) + j * 0.1f + k * 0.01f;
522+
ASSERT_FLOAT_EQ(actual, expected)
523+
<< "QC mismatch in image variable at (" << i << ", " << j << ", "
524+
<< k << ")";
525+
} else {
526+
// For odd indices, we expect NaN
527+
ASSERT_TRUE(std::isnan(actual))
528+
<< "Expected NaN at (" << i << ", " << j << ", " << k
529+
<< ") but got " << actual;
530+
}
531+
}
532+
}
533+
}
534+
// --- End new QC check for the "image" variable ---
535+
}
536+
537+
TEST(Dataset, iselWithStrideAndExistingData) {
538+
std::string testPath = "zarrs/slice_scale_test";
539+
float scaleFactor = 2.5f;
540+
541+
// --- Step 1: Initialize the entire image variable with QC values and Write it ---
542+
{
543+
// Create a new dataset
544+
auto json_vars = GetToyExample();
545+
auto dataset = mdio::Dataset::from_json(json_vars, testPath, mdio::constants::kCreateClean);
546+
ASSERT_TRUE(dataset.status().ok()) << dataset.status();
547+
auto ds = dataset.value();
548+
549+
// Get the "image" variable (expected to be float32_t type)
550+
auto imageVarRes = ds.variables.get<mdio::dtypes::float32_t>("image");
551+
ASSERT_TRUE(imageVarRes.status().ok()) << imageVarRes.status();
552+
auto imageVar = imageVarRes.value();
553+
554+
auto imageDataRes = mdio::from_variable<mdio::dtypes::float32_t>(imageVar);
555+
ASSERT_TRUE(imageDataRes.status().ok()) << imageDataRes.status();
556+
auto imageData = imageDataRes.value();
557+
auto imageAccessor = imageData.get_data_accessor().data();
558+
559+
// Initialize the entire "image" variable with QC values.
560+
// For this test, we assume dimensions 256 x 512 x 384.
561+
for (uint32_t i = 0; i < 256; i++) {
562+
for (uint32_t j = 0; j < 512; j++) {
563+
for (uint32_t k = 0; k < 384; k++) {
564+
imageAccessor[i * (512 * 384) + j * 384 + k] =
565+
static_cast<float>(i) + j * 0.1f + k * 0.01f;
566+
}
567+
}
568+
}
569+
570+
auto writeFut = imageVar.Write(imageData);
571+
ASSERT_TRUE(writeFut.status().ok()) << writeFut.status();
572+
} // End of Step 1
573+
574+
// --- Step 2: Slice with stride of 2 and scale the values of the "image" variable ---
575+
{
576+
// Re-open the dataset for modifications.
577+
auto reopenedDsFut = mdio::Dataset::Open(testPath, mdio::constants::kOpen);
578+
ASSERT_TRUE(reopenedDsFut.status().ok()) << reopenedDsFut.status();
579+
auto ds = reopenedDsFut.value();
580+
581+
// Slice the dataset along the "inline" dimension using a stride of 2.
582+
mdio::RangeDescriptor<mdio::Index> desc = {"inline", 0, 256, 2};
583+
auto sliceRes = ds.isel(desc);
584+
ASSERT_TRUE(sliceRes.status().ok()) << sliceRes.status();
585+
auto ds_slice = sliceRes.value();
586+
587+
// Get the "image" variable from the sliced dataset.
588+
auto imageVarRes = ds_slice.variables.get<mdio::dtypes::float32_t>("image");
589+
ASSERT_TRUE(imageVarRes.status().ok()) << imageVarRes.status();
590+
auto imageVar = imageVarRes.value();
591+
592+
auto imageDataFut = imageVar.Read();
593+
ASSERT_TRUE(imageDataFut.status().ok()) << imageDataFut.status();
594+
auto imageData = imageDataFut.value();
595+
auto imageAccessor = imageData.get_data_accessor().data();
596+
597+
// The sliced "image" now has dimensions 128 x 512 x 384 because we selected every 2nd index.
598+
// Scale each element in the slice by 'scaleFactor'
599+
for (uint32_t ii = 0; ii < 128; ii++) { // 'ii' corresponds to original index i = ii * 2.
600+
for (uint32_t j = 0; j < 512; j++) {
601+
for (uint32_t k = 0; k < 384; k++) {
602+
size_t index = ii * (512 * 384) + j * 384 + k;
603+
imageAccessor[index] *= scaleFactor;
604+
}
605+
}
606+
}
607+
// Write the updated (scaled) data back to the dataset.
608+
auto writeFut = imageVar.Write(imageData);
609+
ASSERT_TRUE(writeFut.status().ok()) << writeFut.status();
610+
} // End of Step 2
611+
612+
// --- Step 3: Read the entire image variable and validate QC values ---
613+
{
614+
// Re-open the dataset for the final validation.
615+
auto reopenedDsFut = mdio::Dataset::Open(testPath, mdio::constants::kOpen);
616+
ASSERT_TRUE(reopenedDsFut.status().ok()) << reopenedDsFut.status();
617+
auto ds = reopenedDsFut.value();
618+
619+
auto imageVarRes = ds.variables.get<mdio::dtypes::float32_t>("image");
620+
ASSERT_TRUE(imageVarRes.status().ok()) << imageVarRes.status();
621+
auto imageVar = imageVarRes.value();
622+
623+
auto imageReadFut = imageVar.Read();
624+
ASSERT_TRUE(imageReadFut.status().ok()) << imageReadFut.status();
625+
auto imageData = imageReadFut.value();
626+
auto imageAccessor = imageData.get_data_accessor().data();
627+
628+
// Validate the values over the entire "image" variable.
629+
// For even inline indices (i % 2 == 0) we expect the initial QC value scaled by 'scaleFactor'.
630+
// For odd inline indices, the original QC values should remain.
631+
for (uint32_t i = 0; i < 256; i++) {
632+
for (uint32_t j = 0; j < 512; j++) {
633+
for (uint32_t k = 0; k < 384; k++) {
634+
size_t index = i * (512 * 384) + j * 384 + k;
635+
float baseValue = static_cast<float>(i) + j * 0.1f + k * 0.01f;
636+
float expected = (i % 2 == 0) ? baseValue * scaleFactor : baseValue;
637+
auto val = imageAccessor[index];
638+
ASSERT_FLOAT_EQ(val, expected)
639+
<< "Mismatch at (" << i << ", " << j << ", " << k
640+
<< "): expected " << expected << ", but got " << val;
641+
}
642+
}
643+
}
644+
} // End of Step 3
645+
}
646+
414647
TEST(Dataset, selValue) {
415648
std::string path = "zarrs/selTester.mdio";
416649
auto dsRes = makePopulated(path);

mdio/variable.h

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1013,7 +1013,6 @@ class Variable {
10131013
* @brief Slices the Variable along the specified dimensions and returns the
10141014
* resulting sub-Variable. This slice is performed as a half open interval.
10151015
* Dimensions that are not described will remain fully intact.
1016-
* @pre The step of the slice descriptor must be 1.
10171016
* @pre The start of the slice descriptor must be less than the stop.
10181017
* @post The resulting Variable will be sliced along the specified dimensions
10191018
* within it's domain. If the slice lay outside of the domain of the Variable,
@@ -1043,7 +1042,6 @@ class Variable {
10431042
stop.reserve(numDescriptors);
10441043
step.reserve(numDescriptors);
10451044
// -1 Everything is ok
1046-
// -2 Error: Step is not 1
10471045
// >=0 Error: Start is greater than or equal to stop
10481046
int8_t preconditionStatus = -1;
10491047

@@ -1052,10 +1050,6 @@ class Variable {
10521050
size_t idx = 0;
10531051
((
10541052
[&] {
1055-
if (desc.step != 1) {
1056-
preconditionStatus = -2;
1057-
return -2;
1058-
}
10591053
auto clampedDesc = sliceInRange(desc);
10601054
if (clampedDesc.start > clampedDesc.stop) {
10611055
preconditionStatus = idx;
@@ -1075,10 +1069,7 @@ class Variable {
10751069
},
10761070
tuple_descs);
10771071

1078-
if (preconditionStatus == -2) {
1079-
return absl::InvalidArgumentError(
1080-
"Only step 1 is supported for slicing.");
1081-
} else if (preconditionStatus >= 0) {
1072+
if (preconditionStatus >= 0) {
10821073
mdio::RangeDescriptor<Index> err;
10831074
std::apply(
10841075
[&](const auto&... desc) {
@@ -1597,8 +1588,6 @@ struct LabeledArray {
15971588

15981589
tensorstore::DimensionIndexBuffer buffer;
15991590

1600-
bool preconditionStatus = true;
1601-
16021591
absl::Status overall_status = absl::OkStatus();
16031592
std::apply(
16041593
[&](const auto&... desc) {
@@ -1614,21 +1603,13 @@ struct LabeledArray {
16141603
overall_status = result; // Capture the error status
16151604
return; // Exit lambda on error
16161605
}
1617-
if (desc.step != 1) {
1618-
preconditionStatus = false;
1619-
}
16201606
dims[idx] = buffer[0];
16211607
}(),
16221608
idx++),
16231609
...);
16241610
},
16251611
tuple_descs);
16261612

1627-
if (!preconditionStatus) {
1628-
return absl::InvalidArgumentError(
1629-
"Only step 1 is supported for slicing.");
1630-
}
1631-
16321613
/// could be we can't slice a dimension
16331614
if (!overall_status.ok()) {
16341615
return overall_status;

mdio/variable_test.cc

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -755,7 +755,7 @@ TEST(Variable, outOfBoundsSlice) {
755755
EXPECT_THAT(badDomain.dimensions().shape(), ::testing::ElementsAre(250, 500))
756756
<< badDomain.dimensions();
757757

758-
mdio::RangeDescriptor<mdio::Index> illegal_step = {"x", 0, 500, 2};
758+
// mdio::RangeDescriptor<mdio::Index> illegal_step = {"x", 0, 500, 2};
759759
// var = mdio::Variable<>::Open(json_good,
760760
// mdio::constants::kCreateClean).result(); auto illegal =
761761
// var.value().slice(illegal_step); EXPECT_FALSE(illegal.status().ok()) <<
@@ -773,12 +773,12 @@ TEST(Variable, outOfBoundsSlice) {
773773
// 500, 2};
774774
auto var1 =
775775
mdio::Variable<>::Open(json_good, mdio::constants::kCreateClean).result();
776-
auto illegal = var1.value().slice(illegal_step);
777-
EXPECT_FALSE(illegal.status().ok())
778-
<< "Step precondition was violated but still sliced";
776+
// auto illegal = var1.value().slice(illegal_step);
777+
// EXPECT_FALSE(illegal.status().ok())
778+
// << "Step precondition was violated but still sliced";
779779

780780
mdio::RangeDescriptor<mdio::Index> illegal_start_stop = {"x", 500, 0, 1};
781-
illegal = var1.value().slice(illegal_start_stop);
781+
auto illegal = var1.value().slice(illegal_start_stop);
782782
EXPECT_FALSE(illegal.status().ok())
783783
<< "Start stop precondition was violated but still sliced";
784784

@@ -895,10 +895,10 @@ TEST(VariableData, outOfBoundsSlice) {
895895
EXPECT_FALSE(outbounds.status().ok())
896896
<< "Slicing out of bounds should fail but did not";
897897

898-
mdio::RangeDescriptor<mdio::Index> illegal_step = {"x", 0, 500, 2};
899-
auto illegal = varData.slice(illegal_step);
900-
EXPECT_FALSE(illegal.status().ok())
901-
<< "Step precondition was violated but still sliced";
898+
// mdio::RangeDescriptor<mdio::Index> illegal_step = {"x", 0, 500, 2};
899+
// auto illegal = varData.slice(illegal_step);
900+
// EXPECT_FALSE(illegal.status().ok())
901+
// << "Step precondition was violated but still sliced";
902902
}
903903

904904
TEST(VariableSpec, open) {

0 commit comments

Comments
 (0)