Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 107 additions & 22 deletions plugins/io/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
|
"""
import base64
import csv
import contextlib
import io
import multiprocessing.dummy
import os
from packaging.version import Version
Expand Down Expand Up @@ -491,9 +493,6 @@ def _import_media_and_labels_inputs(ctx, inputs):
),
view=file_explorer,
)
data_path = _parse_path(ctx, "data_path")
if data_path is None:
return False

labels_path_type = _get_labels_path_type(dataset_type)

Expand Down Expand Up @@ -537,7 +536,11 @@ def _import_media_and_labels_inputs(ctx, inputs):
prop.error_message = f"Please provide a {ext} path"
return False

_add_label_types(ctx, inputs, dataset_type)
data_path = _parse_path(ctx, "data_path")
if data_path is None:
return False

_add_importer_extras(ctx, inputs, dataset_type)

inputs.bool(
"dynamic",
Expand Down Expand Up @@ -694,34 +697,92 @@ def _import_labels_only_inputs(ctx, inputs):
if dataset_dir is None:
return False

_add_label_types(ctx, inputs, dataset_type)
_add_importer_extras(ctx, inputs, dataset_type)

# Don't allow delegation when uploading files
return tab != "UPLOAD"


def _add_importer_extras(ctx, inputs, dataset_type):
    """Adds dataset type-specific import options to the given inputs.

    For ``fot.CSVDataset``, the labels CSV is located from the
    ``dataset_dir``, ``labels_path``, and ``labels_file`` parameters and its
    columns are offered to the user via ``_get_csv_import_fields()``.

    For dataset types that declare more than one supported label type, a
    ``label_types`` selector is added.

    Args:
        ctx: the execution context
        inputs: the inputs object to which to add properties
        dataset_type: the dataset type, as accepted by
            ``_get_dataset_type()``
    """
    d = _get_dataset_type(dataset_type)
    dataset_type = d["dataset_type"]
    supported_types = d.get("label_types", None)

    # NOTE: there is deliberately no early return on `supported_types` here;
    # the CSV options below must be added even when a dataset type declares
    # no label types
    if dataset_type == fot.CSVDataset:
        dataset_dir = _parse_path(ctx, "dataset_dir")
        labels_path = _parse_path(ctx, "labels_path")
        _, labels_bytes = _parse_file(ctx, "labels_file")

        if dataset_dir is not None:
            if labels_path is not None:
                labels_path = fos.join(dataset_dir, labels_path)
            else:
                # Fall back to `labels.csv` inside the dataset directory
                labels_path = fos.join(dataset_dir, "labels.csv")

        if labels_path is not None:
            _get_csv_import_fields(ctx, inputs, csv_path=labels_path)

        if labels_bytes is not None:
            _get_csv_import_fields(ctx, inputs, csv_bytes=labels_bytes)

    if supported_types is not None and len(supported_types) > 1:
        label_type_choices = types.DropdownView(multiple=True)
        for label_type in supported_types:
            label_type_choices.add_choice(label_type, label=label_type)

        inputs.list(
            "label_types",
            types.String(),
            default=None,
            label="Label types",
            description=(
                "The label type(s) to load. By default, all label types are "
                "loaded"
            ),
            view=label_type_choices,
        )


# Backwards-compatible alias for call sites that still use the previous name
_add_label_types = _add_importer_extras


def _get_csv_import_fields(ctx, inputs, csv_path=None, csv_bytes=None):
    """Adds inputs for choosing which columns of a CSV file to import.

    Two properties are defined from the CSV's header row: an optional
    ``csv_fields`` list restricting the columns to import, and a required
    ``media_field`` string naming the column that holds each row's media
    path.

    Args:
        ctx: the execution context (not read by this function)
        inputs: the inputs object to which to add properties
        csv_path (None): the path to the CSV file
        csv_bytes (None): the raw bytes of the CSV file
    """
    fieldnames = _get_csv_fieldnames(csv_path=csv_path, csv_bytes=csv_bytes)

    # `csv_fields` is a list, so its dropdown allows multiple selections.
    # `media_field` is a single string, so it gets its own single-select
    # dropdown rather than sharing the multi-select one
    field_choices = types.DropdownView(multiple=True)
    media_field_choices = types.DropdownView()
    for field in fieldnames:
        field_choices.add_choice(field, label=field)
        media_field_choices.add_choice(field, label=field)

    inputs.list(
        "csv_fields",
        types.String(),
        required=False,
        default=None,
        label="Fields",
        description="An optional subset of column(s) to import",
        view=field_choices,
    )

    inputs.str(
        "media_field",
        required=True,
        # Preselect the `filepath` column when the CSV has one
        default="filepath" if "filepath" in fieldnames else None,
        label="Media field",
        description=(
            "The name of the column containing the media path for each row"
        ),
        view=media_field_choices,
    )


def _get_csv_fieldnames(csv_path=None, csv_bytes=None):
    """Returns the column names in the header row of a CSV file.

    One of ``csv_path`` or ``csv_bytes`` must be provided. If both are
    provided, ``csv_path`` takes precedence.

    Args:
        csv_path (None): the path to a CSV file, which is opened via
            ``fos.open_file()``
        csv_bytes (None): the raw bytes of a CSV file, which are decoded as
            UTF-8

    Returns:
        a list of column names, which is empty if the CSV has no header row

    Raises:
        ValueError: if neither ``csv_path`` nor ``csv_bytes`` is provided
    """
    if csv_path is not None:
        f = fos.open_file(csv_path, "r")
    elif csv_bytes is not None:
        # `utf-8-sig` drops a leading byte order mark, if any, so that it
        # doesn't end up glued to the first column name
        text = csv_bytes.decode("utf-8-sig")  # pylint: disable=no-member

        # The `csv` module expects its inputs to use `newline=""`
        f = io.StringIO(text, newline="")
    else:
        raise ValueError("Either `csv_path` or `csv_bytes` must be provided")

    with f:
        reader = csv.DictReader(f)

        # `fieldnames` is None when the file contains no rows at all
        return list(reader.fieldnames or [])


def _upload_media_inputs(ctx, inputs):
style = ctx.params.get("style", None)

Expand Down Expand Up @@ -776,11 +837,9 @@ def _upload_media_inputs(ctx, inputs):


def _upload_media_bytes(ctx):
media_obj = ctx.params["media_file"]
filename, content = _parse_file(ctx, "media_file")
upload_dir = _parse_path(ctx, "upload_dir")
overwrite = ctx.params["overwrite"]
filename = media_obj["name"]
content = base64.b64decode(media_obj["content"])

if overwrite:
outpath = fos.join(upload_dir, filename)
Expand Down Expand Up @@ -864,6 +923,15 @@ def _import_media_and_labels(ctx):
if label_types is not None:
kwargs["label_types"] = label_types

if dataset_type == fot.CSVDataset:
csv_fields = ctx.params.get("csv_fields", None)
if csv_fields is not None:
kwargs["fields"] = csv_fields

media_field = ctx.params.get("media_field", None)
if media_field is not None:
kwargs["media_field"] = media_field

# @todo can remove version check if we require `fiftyone>=1.6.0`
if ctx.delegated and Version(foc.VERSION) >= Version("1.6.0"):
progress = lambda pb: ctx.set_progress(progress=pb.progress)
Expand All @@ -885,9 +953,7 @@ def _import_media_and_labels(ctx):


def _upload_labels_bytes(ctx, tmp_dir):
labels_obj = ctx.params["labels_file"]
filename = labels_obj["name"]
content = base64.b64decode(labels_obj["content"])
filename, content = _parse_file(ctx, "labels_file")

outpath = fos.join(tmp_dir, filename)
fos.write_file(content, outpath)
Expand All @@ -909,6 +975,15 @@ def _import_labels_only(ctx):
if label_types is not None:
kwargs["label_types"] = label_types

if dataset_type == fot.CSVDataset:
csv_fields = ctx.params.get("csv_fields", None)
if csv_fields is not None:
kwargs["fields"] = csv_fields

media_field = ctx.params.get("media_field", None)
if media_field is not None:
kwargs["media_field"] = media_field

# @todo can remove version check if we require `fiftyone>=1.6.0`
if ctx.delegated and Version(foc.VERSION) >= Version("1.6.0"):
progress = lambda pb: ctx.set_progress(progress=pb.progress)
Expand Down Expand Up @@ -2824,6 +2899,16 @@ def _to_path(value):
return {"absolute_path": value}


def _parse_file(ctx, key):
    """Extracts the name and decoded contents of an uploaded file parameter.

    Args:
        ctx: the execution context whose ``params`` are read
        key: the name of the file parameter

    Returns:
        a ``(filename, content)`` tuple, where ``content`` is the
        base64-decoded bytes of the file, or ``(None, None)`` if there is no
        value for ``key``
    """
    upload = ctx.params.get(key)
    if upload is not None:
        return upload["name"], base64.b64decode(upload["content"])

    return None, None


def _to_list(value):
if value is None:
return None
Expand Down