Refactor: unify task detection into one core; derive modality from model class, not config-field heuristics #877

@timenick

Summary

detect_task (used by inspect/eval) and _detect_task_and_class_from_config
(used by config/build) are two implementations of the same task decision over
the same MODEL_CLASS_MAPPING data, and a third entry point
(resolve_loader_config step 2, used by --model-type) resolves the task a third
way via get_supported_tasks(type)[0]. They have drifted, and the
modality-disambiguation step (D2) reconstructs modality from config field names,
a heuristic that is provably weaker than information the pipeline already holds.

Proposal: extract a single task-detection core + a single task-override lookup that
all three entry points consult, and derive modality from the resolved model
class's main_input_name.

Internal refactor only — public detect_task / resolve_task_and_model_class
signatures stay unchanged.

Motivation 1: the D2 false positive (concrete bug)

D2 upgrades feature-extraction → image-feature-extraction when the config has a
top-level image_size/patch_size (OR semantics). patch_size is not exclusive
to vision: spectrogram transformers patchify their mel-spectrogram. Verified via
the real detection path:

| declared architecture | inspect task today | correct modality |
| --- | --- | --- |
| Wav2Vec2Model | feature-extraction → text dataset → fails | audio |
| WhisperModel | feature-extraction → text dataset → fails | audio |
| ASTModel | image-feature-extraction → image dataset → fails | audio |
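
The false positive can be reproduced with a minimal sketch of the heuristic. The function d2_upgrade and the trimmed config dicts below are hypothetical stand-ins written for illustration, not the project's actual code; only the OR-over-two-fields rule is taken from the description above.

```python
# Hypothetical reconstruction of the D2 config-field heuristic (illustration only).
VISION_FIELDS = ("image_size", "patch_size")


def d2_upgrade(task: str, config: dict) -> str:
    """Upgrade feature-extraction when ANY vision-looking field is present (OR semantics)."""
    if task == "feature-extraction" and any(field in config for field in VISION_FIELDS):
        return "image-feature-extraction"
    return task


# Trimmed, illustrative configs: only fields relevant to the heuristic are kept.
vit_config = {"model_type": "vit", "image_size": 224, "patch_size": 16}
ast_config = {"model_type": "audio-spectrogram-transformer", "patch_size": 16, "num_mel_bins": 128}
wav2vec2_config = {"model_type": "wav2vec2", "hidden_size": 768}

for cfg in (vit_config, ast_config, wav2vec2_config):
    print(cfg["model_type"], "->", d2_upgrade("feature-extraction", cfg))
```

The AST config carries patch_size because the model patchifies its spectrogram, so the heuristic labels an audio model as an image model; the wav2vec2 config has neither field and stays on the text path.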

Motivation 2: sam2 resolves three different ways (entry-point drift)

sam2 is a multi-task type (MODEL_CLASS_MAPPING registers feature-extraction,
image-feature-extraction, image-segmentation, mask-generation) plus a
(sam2, None) sentinel whose canonical export target is the mask-generation decoder.
The three entry points disagree (verified):

| entry point | path | sam2 result |
| --- | --- | --- |
| inspect (detect_task) | short-circuit fires only for a single real task → falls through to TasksManager on Sam2Model (headless) | feature-extraction |
| build -m <ckpt> (_detect_task_and_class_from_config) | applies the (sam2, None) sentinel | mask-generation (decoder) |
| config/build --model-type sam2 (resolve_loader_config step 2) | get_supported_tasks("sam2")[0]; bypasses the sentinel entirely | feature-extraction (encoder) |

So build -m <sam2> produces the decoder while build --model-type sam2 produces
the encoder — different artifacts for the same model — and inspect predicts
neither reliably. sam (single real task) does not exhibit this because its
short-circuit and sentinel both yield mask-generation.

Root cause

Modality is a property of the model class (known with certainty), but the pipeline
collapses class → task through TasksManager (modality-blind by design), discarding
modality, then D2 reconstructs it from config fields. The class that carries the
authoritative signal is already resolved at that exact point
(_detect_task_from_config → _resolve_model_class_from_config), so D2 pays a
heuristic's fragility to avoid a cost already incurred.

main_input_name (an HF framework convention) is the authoritative, offline,
architecture-agnostic modality signal:

| main_input_name | modality | upgrade applied to feature-extraction |
| --- | --- | --- |
| input_ids | text | none → feature-extraction |
| pixel_values | image | image-feature-extraction |
| input_values / input_features | audio | none → stays feature-extraction (see Decisions (b)) |

It also handles the CLIP text/vision split correctly
(CLIPTextModelWithProjection → input_ids, CLIPVisionModelWithProjection → pixel_values),
which the config-field table cannot, and stops the AST→image misroute without any new
task name.
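
A minimal sketch of the class-based modality upgrade, assuming the mapping in the table above. The names modality_of, resolve_task_modality and the three stub classes are hypothetical; the stubs only mimic the main_input_name attribute so the example runs without transformers installed.

```python
from typing import Optional

# main_input_name -> modality, as listed in the table above.
MODALITY_BY_MAIN_INPUT = {
    "input_ids": "text",
    "pixel_values": "image",
    "input_values": "audio",
    "input_features": "audio",
}


def modality_of(model_class: type) -> Optional[str]:
    """Return the modality implied by the class's main_input_name, or None if unknown."""
    return MODALITY_BY_MAIN_INPUT.get(getattr(model_class, "main_input_name", None))


def resolve_task_modality(task: str, model_class: type) -> str:
    """Upgrade feature-extraction only when the resolved class takes image input."""
    if task == "feature-extraction" and modality_of(model_class) == "image":
        return "image-feature-extraction"
    return task


# Hypothetical stand-ins for resolved model classes.
class TextEncoderStub:
    main_input_name = "input_ids"


class VisionEncoderStub:
    main_input_name = "pixel_values"


class SpectrogramEncoderStub:  # mimics an AST-style audio backbone
    main_input_name = "input_values"


for cls in (TextEncoderStub, VisionEncoderStub, SpectrogramEncoderStub):
    print(cls.__name__, modality_of(cls), resolve_task_modality("feature-extraction", cls))
```

Because the decision reads the class rather than the config, a spectrogram model with a patch_size field is never inspected for that field and stays feature-extraction.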

Known limit (intentional): main_input_name separates text/image/audio but is
not injective for image vs video — classic video transformers (VideoMAE /
Timesformer / Vivit) share pixel_values. Video is out of scope (optimum has no ONNX
export config for these types); the image/video case is handled by the calibration
fallback below, not by main_input_name.

Four inconsistencies the merge eliminates

  1. Model-type override (single vs multi task): detect_task's distinct_tasks
    short-circuit fires only when a type maps to exactly one real task; the
    (model_type, None) sentinel is ignored on the detect path. So multi-task types
    with a sentinel (sam2 / sam2-video) fall through to TasksManager and report the
    headless feature-extraction, while build's _detect_task_and_class_from_config
    applies the sentinel → mask-generation. (#841, "fix(task): make detect_task architecture-aware for multi-task model types", synced the single-task half only.)
  2. --model-type path bypasses the override: resolve_loader_config step 2 uses
    get_supported_tasks(model_type)[0] for the --model-type entry, consulting
    neither the short-circuit nor the sentinel. This is the third sam2 answer above.
  3. Model-id override: get_default_task_for_model_id (e.g. prajjwal1/bert-tiny)
    is applied on the build path only; detect_task skips it, so inspect can
    disagree with build today.
  4. Modality signal: config fields (the AST bug) vs the resolved class.

Decision: sam2 / sam2-video resolve to mask-generation on every entry point

For multi-task types with a (type, None) sentinel, the sentinel's canonical export
target is authoritative on all entry points. So inspect, build -m, and
build --model-type all resolve sam2 → mask-generation (matching sam, which
already does). --task feature-extraction / image-feature-extraction still select
the encoder explicitly. This changes sam2/sam2-video inspect and --model-type
output, and updates test_detect_task_falls_through_for_multi_task_model_type_sam2
(which pinned the old fall-through) with justification.

Proposed architecture

_resolve_task_override(model_type, model_id) -> task | None
        single place encoding model_type / model_id canonical task overrides:
          1. model-id default  (get_default_task_for_model_id)
          2. (model_type, None) sentinel -> reverse-lookup default task
          3. single real task  (exactly one non-None task for the model_type)
        (replaces the short-circuit, the sentinel reverse-lookup, AND becomes the
         first consult in resolve_loader_config step 2)

_detect(config) -> (task, model_class | None, source)
        the one task-detection core: override -> wrapped-library -> resolve class
        -> infer task -> fill-mask->seq2seq upgrade

detect_task(config)               = _detect -> modality-upgrade -> drop class
resolve_task_and_model_class C1   = _detect -> ensure class -> modality-upgrade
resolve_loader_config step 2      = _resolve_task_override(model_type) first;
                                    only fall to get_supported_tasks(type)[0] if None

All three entry points now consult _resolve_task_override before any fallback, so
sam2 is consistent everywhere. Build-specific class resolution
(get_model_class_for_task, specialization, arch fallback) stays in the build entry
layer. The short-circuit's "answer without importing optimum" optimization is
preserved when the override hits.
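
The three-step override lookup could look roughly like the sketch below. It assumes MODEL_CLASS_MAPPING is keyed by (model_type, task) with class objects or names as values, which the issue does not spell out; the toy mapping, the class names, and the bert-tiny default task value are placeholders invented for illustration.

```python
from typing import Optional

# Toy stand-in for MODEL_CLASS_MAPPING. Key shape and class names are assumptions.
MODEL_CLASS_MAPPING = {
    ("sam2", "feature-extraction"): "Sam2EncoderStub",
    ("sam2", "image-feature-extraction"): "Sam2EncoderStub",
    ("sam2", "image-segmentation"): "Sam2SegmenterStub",
    ("sam2", "mask-generation"): "Sam2DecoderStub",
    ("sam2", None): "Sam2DecoderStub",  # sentinel: canonical export target
    ("sam", "mask-generation"): "SamDecoderStub",
    ("sam", None): "SamDecoderStub",
    ("bert", "fill-mask"): "BertMaskedLMStub",
    ("bert", "feature-extraction"): "BertEncoderStub",
}

# Placeholder for get_default_task_for_model_id; the task value is not from the issue.
MODEL_ID_DEFAULT_TASKS = {"prajjwal1/bert-tiny": "feature-extraction"}


def resolve_task_override(model_type: str, model_id: Optional[str] = None) -> Optional[str]:
    """Return the canonical task for a model_type / model_id, or None if no override applies."""
    # 1. Model-id default wins over everything else.
    if model_id in MODEL_ID_DEFAULT_TASKS:
        return MODEL_ID_DEFAULT_TASKS[model_id]

    # 2. (model_type, None) sentinel: reverse-lookup the real task sharing its class.
    sentinel_class = MODEL_CLASS_MAPPING.get((model_type, None))
    if sentinel_class is not None:
        for (mtype, task), cls in MODEL_CLASS_MAPPING.items():
            if mtype == model_type and task is not None and cls == sentinel_class:
                return task

    # 3. Exactly one real (non-None) task registered for the type.
    real_tasks = {t for (mtype, t) in MODEL_CLASS_MAPPING if mtype == model_type and t is not None}
    if len(real_tasks) == 1:
        return next(iter(real_tasks))

    return None  # caller falls through to class resolution / get_supported_tasks


print(resolve_task_override("sam2"))
print(resolve_task_override("bert"))
print(resolve_task_override("bert", "prajjwal1/bert-tiny"))
```

With one lookup shared by inspect, build -m, and --model-type, sam2 yields the same answer regardless of which entry point asks, and types without an override (bert here) still fall through to the existing detection path.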

Decisions

  • (a) Remove the D2 config-field table outright. Every path yielding
    feature-extraction either comes from the override mapping (already
    modality-aware) or from TasksManager (class resolved → main_input_name
    available). No path holds feature-extraction without a class, so the
    heuristic is dead weight. Keep _resolve_task_modality as the single modality
    entry point, re-implemented on main_input_name.

  • (b) Do NOT add audio-feature-extraction; the modality upgrade fires for image only.
    Audio backbones stay feature-extraction, which is exactly what optimum
    exports them as (verified: wav2vec2/whisper/ast/hubert/wavlm list
    feature-extraction, not an audio variant). Switching to main_input_name already
    fixes the AST→image misroute without a new name. audio-feature-extraction has
    no HF pipeline, no evaluator, no dataset, and no optimum task behind it; adding
    it to KNOWN_TASKS would advertise a capability that does not exist. It is
    introduced only when a real audio feature-extraction downstream (evaluator +
    dataset + recipe) is built — at which point the name is earned, as
    image-feature-extraction already is.

    Principle: a modality-aware task variant exists only when a modality-specific
    dataset + evaluator exists for it. image-feature-extraction qualifies;
    audio-feature-extraction does not yet.

In scope: universal calibration fallback

With (b), a bare audio backbone stays feature-extraction, so quant calibration would
still route to TextDataset and crash (audio models have no tokenizer). Fix this
modality-agnostically: when the calibration dataset's produced inputs don't intersect
the ONNX input names, fall back to RandomDataset (reads the real inputs straight
from the model). This also covers the image/video collision and any future modality,
without inventing per-modality task names. (Promoted from out-of-scope — it is the
mechanism that keeps build from crashing once audio stays feature-extraction.)
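
The fallback rule reduces to a set-intersection check. The sketch below assumes the input names a dataset would produce can be known before it is instantiated; pick_calibration_dataset and the string return values are hypothetical, with only the TextDataset and RandomDataset names taken from the text above.

```python
from typing import Iterable


def pick_calibration_dataset(
    preferred: str,
    produced_inputs: Iterable[str],
    onnx_input_names: Iterable[str],
) -> str:
    """Keep the task's preferred dataset only if it feeds at least one ONNX input."""
    if set(produced_inputs) & set(onnx_input_names):
        return preferred
    # No overlap: the dataset cannot drive this model, so use random inputs
    # shaped from the model's own input signature instead.
    return "RandomDataset"


text_inputs = ["input_ids", "attention_mask"]

# Text model: the text dataset matches, so it is kept.
print(pick_calibration_dataset("TextDataset", text_inputs, ["input_ids", "attention_mask"]))

# Audio backbone resolved as feature-extraction: nothing overlaps, so fall back.
print(pick_calibration_dataset("TextDataset", text_inputs, ["input_values"]))
```

The check never consults the task name or the modality, which is why it extends to modalities that have no dedicated task variant.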

Out of scope

  • A real audio-feature-extraction / audio-classification evaluator + default dataset.
  • Video task support — optimum has no ONNX export config for videomae/timesformer/vivit
    (explicitly unsupported upstream).

Testing

Parametrized pytest over representative configs (text / image / audio / CLIP-dual /
SAM / SAM2 / bert-tiny), asserting:

  • detect_task == task from resolve_task_and_model_class == task from
    resolve_loader_config(model_type=...) for the same model — the cross-entry-point
    consistency guarantee.
  • sam2 / sam2-video → mask-generation on all three entry points (was
    feature-extraction on inspect and --model-type); update
    test_detect_task_falls_through_for_multi_task_model_type_sam2.
  • AST → feature-extraction (no longer image-feature-extraction); audio backbone
    build completes via the RandomDataset fallback.
  • bert-tiny model-id override fires on the detect path too.
  • vision backbones still → image-feature-extraction; no regression on the #841
    table (bart-mnli, sam, clip).
