Skip to content

Commit 2920718

Browse files
committed
update unit tests
Signed-off-by: Paul Dittamo <[email protected]>
1 parent e4500c3 commit 2920718

File tree

2 files changed

+106
-0
lines changed

2 files changed

+106
-0
lines changed

tests/flytekit/unit/core/test_array_node.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,3 +378,47 @@ def test_map_task_wrapper():
378378

379379
mapped_lp = map_task(lp)(a=[1, 3, 5], b=[2, 4, 6], c=[7, 8, 9])
380380
assert mapped_lp == [14, 96, 270]
381+
382+
383+
def test_run_all_sub_nodes_default():
384+
node = array_node(lp, concurrency=10, min_success_ratio=0.9)
385+
assert node.run_all_sub_nodes is False
386+
387+
388+
def test_run_all_sub_nodes_set():
389+
node = array_node(lp, concurrency=10, min_success_ratio=0.9, run_all_sub_nodes=True)
390+
assert node.run_all_sub_nodes is True
391+
392+
393+
def test_run_all_sub_nodes_serialization(serialization_settings):
394+
@workflow
395+
def wf_run_all() -> typing.List[int]:
396+
return map_task(lp, concurrency=10, min_success_ratio=0.9, run_all_sub_nodes=True)(
397+
a=[1, 3, 5], b=["two", 4, "six"], c=[7, 8, 9]
398+
)
399+
400+
od = OrderedDict()
401+
wf_spec = get_serializable(od, serialization_settings, wf_run_all)
402+
403+
parent_node = wf_spec.template.nodes[0]
404+
assert parent_node.array_node._run_all_sub_nodes is True
405+
406+
pb = parent_node.array_node.to_flyte_idl()
407+
assert pb.run_all_sub_nodes is True
408+
409+
410+
def test_run_all_sub_nodes_serialization_default(serialization_settings):
411+
@workflow
412+
def wf_no_run_all() -> typing.List[int]:
413+
return map_task(lp, concurrency=10, min_success_ratio=0.9)(
414+
a=[1, 3, 5], b=["two", 4, "six"], c=[7, 8, 9]
415+
)
416+
417+
od = OrderedDict()
418+
wf_spec = get_serializable(od, serialization_settings, wf_no_run_all)
419+
420+
parent_node = wf_spec.template.nodes[0]
421+
assert parent_node.array_node._run_all_sub_nodes is False
422+
423+
pb = parent_node.array_node.to_flyte_idl()
424+
assert pb.run_all_sub_nodes is False

tests/flytekit/unit/core/test_array_node_map_task.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,7 @@ def wf1(x: typing.List[int]):
578578
assert array_node.array_node._parallelism == 10
579579
assert not array_node.array_node._is_original_sub_node_interface
580580
assert array_node.array_node._execution_mode == _core_workflow.ArrayNode.MINIMAL_STATE
581+
assert not array_node.array_node._run_all_sub_nodes
581582
task_spec = od[arraynode_maptask]
582583
assert task_spec.template.metadata.retries.retries == 2
583584
assert task_spec.template.metadata.interruptible
@@ -588,6 +589,67 @@ def wf1(x: typing.List[int]):
588589
assert array_node.array_node._execution_mode == _core_workflow.ArrayNode.FULL_STATE
589590

590591

592+
def test_run_all_sub_nodes_default():
593+
@task
594+
def t1(a: int) -> int:
595+
return a + 1
596+
597+
mt = map_task(t1)
598+
assert mt.run_all_sub_nodes is False
599+
600+
601+
def test_run_all_sub_nodes_set():
602+
@task
603+
def t1(a: int) -> int:
604+
return a + 1
605+
606+
mt = map_task(t1, run_all_sub_nodes=True)
607+
assert mt.run_all_sub_nodes is True
608+
609+
610+
def test_run_all_sub_nodes_serialization(serialization_settings):
611+
@task
612+
def t1(a: int) -> int:
613+
return a + 1
614+
615+
arraynode_maptask = map_task(t1, run_all_sub_nodes=True)
616+
617+
@workflow
618+
def wf(x: typing.List[int]):
619+
return arraynode_maptask(a=x)
620+
621+
od = OrderedDict()
622+
wf_spec = get_serializable(od, serialization_settings, wf)
623+
624+
array_node = wf_spec.template.nodes[0]
625+
assert array_node.array_node._run_all_sub_nodes is True
626+
627+
# Verify it serializes to the protobuf correctly
628+
pb = array_node.array_node.to_flyte_idl()
629+
assert pb.run_all_sub_nodes is True
630+
631+
632+
def test_run_all_sub_nodes_serialization_default(serialization_settings):
633+
@task
634+
def t1(a: int) -> int:
635+
return a + 1
636+
637+
arraynode_maptask = map_task(t1)
638+
639+
@workflow
640+
def wf(x: typing.List[int]):
641+
return arraynode_maptask(a=x)
642+
643+
od = OrderedDict()
644+
wf_spec = get_serializable(od, serialization_settings, wf)
645+
646+
array_node = wf_spec.template.nodes[0]
647+
assert array_node.array_node._run_all_sub_nodes is False
648+
649+
pb = array_node.array_node.to_flyte_idl()
650+
assert pb.run_all_sub_nodes is False
651+
652+
591653
def test_serialization_extended_resources(serialization_settings):
592654
@task(
593655
accelerator=GPUAccelerator("test_gpu"),

0 commit comments

Comments
 (0)