Skip to content

Commit b0f7b55

Browse files
feat(z-image): add Z-Image Base (undistilled) model variant support (#8799)
* feat(z-image): add Z-Image Base (undistilled) model variant support - Add ZImageVariantType enum with 'turbo' and 'zbase' variants - Auto-detect variant on import via scheduler_config.json shift value (3.0=turbo, 6.0=zbase) - Add database migration to populate variant field for existing Z-Image models - Re-add LCM scheduler with variant-aware filtering (LCM hidden for zbase) - Auto-reset scheduler to Euler when switching to zbase model if LCM selected - Update frontend to show/hide LCM option based on model variant - Add toast notification when scheduler is auto-reset Z-Image Base models are undistilled and require more steps (28-50) with higher guidance (3.0-5.0), while Z-Image Turbo is distilled for ~8 steps with CFG 1.0. LCM scheduler only works with distilled (Turbo) models. * Chore ruff format * Chore fix windows path * feat(z-image): filter LoRAs by variant compatibility and warn on mismatch LoRA picker now hides Z-Image LoRAs with incompatible variants (e.g. ZBase LoRAs when using Turbo model). LoRAs without a variant are always shown. Backend loaders warn at runtime if a LoRA variant doesn't match the transformer variant. * Chore typegen --------- Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
1 parent 467ae66 commit b0f7b55

File tree

19 files changed

+302
-39
lines changed

19 files changed

+302
-39
lines changed

invokeai/app/invocations/z_image_denoise.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,8 @@ class ZImageDenoiseInvocation(BaseInvocation):
107107
# Scheduler selection for the denoising process
108108
scheduler: ZIMAGE_SCHEDULER_NAME_VALUES = InputField(
109109
default="euler",
110-
description="Scheduler (sampler) for the denoising process. Euler is the default and recommended for "
111-
"Z-Image-Turbo. Heun is 2nd-order (better quality, 2x slower). LCM is optimized for few steps.",
110+
description="Scheduler (sampler) for the denoising process. Euler is the default and recommended. "
111+
"Heun is 2nd-order (better quality, 2x slower). LCM works with Turbo only (not Base).",
112112
ui_choice_labels=ZIMAGE_SCHEDULER_LABELS,
113113
)
114114

@@ -387,12 +387,11 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
387387
num_train_timesteps=1000,
388388
shift=1.0,
389389
)
390-
# Set timesteps - LCM should use num_inference_steps (it has its own sigma schedule),
390+
# Set timesteps - LCM uses its own sigma schedule (num_inference_steps),
391391
# while other schedulers can use custom sigmas if supported
392392
is_lcm = self.scheduler == "lcm"
393393
set_timesteps_sig = inspect.signature(scheduler.set_timesteps)
394394
if not is_lcm and "sigmas" in set_timesteps_sig.parameters:
395-
# Convert sigmas list to tensor for scheduler
396395
scheduler.set_timesteps(sigmas=sigmas, device=device)
397396
else:
398397
# LCM or scheduler doesn't support custom sigmas - use num_inference_steps
@@ -644,10 +643,8 @@ def _run_diffusion(self, context: InvocationContext) -> torch.Tensor:
644643
),
645644
)
646645
else:
647-
# For LCM and other first-order schedulers
646+
# For first-order schedulers (Euler, LCM)
648647
user_step += 1
649-
# Only call step_callback if we haven't exceeded total_steps
650-
# (LCM scheduler may have more internal steps than user-facing steps)
651648
if user_step <= total_steps:
652649
pbar.update(1)
653650
step_callback(

invokeai/app/invocations/z_image_lora_loader.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,18 @@ def invoke(self, context: InvocationContext) -> ZImageLoRALoaderOutput:
6666
if self.qwen3_encoder and any(lora.lora.key == lora_key for lora in self.qwen3_encoder.loras):
6767
raise ValueError(f'LoRA "{lora_key}" already applied to Qwen3 encoder.')
6868

69+
# Warn on variant mismatch between LoRA and transformer.
70+
lora_config = context.models.get_config(lora_key)
71+
lora_variant = getattr(lora_config, "variant", None)
72+
if lora_variant and self.transformer is not None:
73+
transformer_config = context.models.get_config(self.transformer.transformer.key)
74+
transformer_variant = getattr(transformer_config, "variant", None)
75+
if transformer_variant and lora_variant != transformer_variant:
76+
context.logger.warning(
77+
f"LoRA variant mismatch: LoRA '{lora_config.name}' is for {lora_variant.value} "
78+
f"but transformer is {transformer_variant.value}. This may cause unexpected results."
79+
)
80+
6981
output = ZImageLoRALoaderOutput()
7082

7183
# Attach LoRA layers to the models.
@@ -142,6 +154,18 @@ def invoke(self, context: InvocationContext) -> ZImageLoRALoaderOutput:
142154
"not Z-Image models. Ensure you are using a Z-Image compatible LoRA."
143155
)
144156

157+
# Warn on variant mismatch between LoRA and transformer.
158+
lora_config = context.models.get_config(lora.lora.key)
159+
lora_variant = getattr(lora_config, "variant", None)
160+
if lora_variant and self.transformer is not None:
161+
transformer_config = context.models.get_config(self.transformer.transformer.key)
162+
transformer_variant = getattr(transformer_config, "variant", None)
163+
if transformer_variant and lora_variant != transformer_variant:
164+
context.logger.warning(
165+
f"LoRA variant mismatch: LoRA '{lora_config.name}' is for {lora_variant.value} "
166+
f"but transformer is {transformer_variant.value}. This may cause unexpected results."
167+
)
168+
145169
added_loras.append(lora.lora.key)
146170

147171
if self.transformer is not None and output.transformer is not None:

invokeai/app/services/model_records/model_records_base.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
ModelVariantType,
2828
Qwen3VariantType,
2929
SchedulerPredictionType,
30+
ZImageVariantType,
3031
)
3132

3233

@@ -92,9 +93,9 @@ class ModelRecordChanges(BaseModelExcludeNull):
9293

9394
# Checkpoint-specific changes
9495
# TODO(MM2): Should we expose these? Feels footgun-y...
95-
variant: Optional[ModelVariantType | ClipVariantType | FluxVariantType | Flux2VariantType | Qwen3VariantType] = (
96-
Field(description="The variant of the model.", default=None)
97-
)
96+
variant: Optional[
97+
ModelVariantType | ClipVariantType | FluxVariantType | Flux2VariantType | ZImageVariantType | Qwen3VariantType
98+
] = Field(description="The variant of the model.", default=None)
9899
prediction_type: Optional[SchedulerPredictionType] = Field(
99100
description="The prediction type of the model.", default=None
100101
)

invokeai/app/services/shared/sqlite/sqlite_util.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_23 import build_migration_23
2929
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_24 import build_migration_24
3030
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_25 import build_migration_25
31+
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_26 import build_migration_26
3132
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_impl import SqliteMigrator
3233

3334

@@ -73,6 +74,7 @@ def init_db(config: InvokeAIAppConfig, logger: Logger, image_files: ImageFileSto
7374
migrator.register_migration(build_migration_23(app_config=config, logger=logger))
7475
migrator.register_migration(build_migration_24(app_config=config, logger=logger))
7576
migrator.register_migration(build_migration_25(app_config=config, logger=logger))
77+
migrator.register_migration(build_migration_26(app_config=config, logger=logger))
7678
migrator.run_migrations()
7779

7880
return db
Lines changed: 115 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import json
2+
import sqlite3
3+
from logging import Logger
4+
from pathlib import Path
5+
from typing import Any
6+
7+
from invokeai.app.services.config import InvokeAIAppConfig
8+
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
9+
from invokeai.backend.model_manager.taxonomy import BaseModelType, ModelFormat, ModelType, ZImageVariantType
10+
11+
12+
class Migration26Callback:
    """Back-fills the ``variant`` field on existing Z-Image main model configs.

    Z-Image models installed before variant support was added have no
    ``variant`` key; this callback inspects each stored config and writes
    either ``turbo`` or ``zbase`` back into the ``models`` table.
    """

    def __init__(self, app_config: InvokeAIAppConfig, logger: Logger) -> None:
        self._app_config = app_config
        self._logger = logger

    def _detect_variant_from_scheduler(self, model_path: Path) -> ZImageVariantType:
        """Detect Z-Image variant from scheduler config for Diffusers models.

        Z-Image variants are distinguished by the scheduler shift value:
        - Turbo (distilled): shift = 3.0
        - Base (undistilled): shift = 6.0
        """
        config_path = model_path / "scheduler" / "scheduler_config.json"

        # No scheduler config on disk: fall back to Turbo, the only variant
        # that existed before Base support was added.
        if not config_path.exists():
            return ZImageVariantType.Turbo

        try:
            with open(config_path, "r", encoding="utf-8") as fh:
                shift = json.load(fh).get("shift", 3.0)
        except (json.JSONDecodeError, OSError) as e:
            self._logger.warning(f"Could not read scheduler config: {e}, defaulting to Turbo")
            return ZImageVariantType.Turbo

        # ZBase (undistilled) uses shift = 6.0, Turbo uses shift = 3.0.
        return ZImageVariantType.ZBase if shift >= 5.0 else ZImageVariantType.Turbo

    def __call__(self, cursor: sqlite3.Cursor) -> None:
        """Run the migration against the open database cursor."""
        cursor.execute("SELECT id, config FROM models;")

        n_turbo = 0
        n_zbase = 0

        for model_id, raw_config in cursor.fetchall():
            try:
                config: dict[str, Any] = json.loads(raw_config)
            except json.JSONDecodeError as e:
                # A corrupt model record is unrecoverable here; surface it.
                self._logger.error("Invalid config JSON for model %s: %s", model_id, e)
                raise

            # Only Z-Image main models are in scope for this migration.
            if config.get("base") != BaseModelType.ZImage.value:
                continue
            if config.get("type") != ModelType.Main.value:
                continue

            # Leave records alone that already carry a variant.
            if "variant" in config:
                continue

            fmt = config.get("format")
            path = config.get("path")

            if fmt == ModelFormat.Diffusers.value and path:
                # Diffusers layouts ship a scheduler config we can inspect.
                variant = self._detect_variant_from_scheduler(Path(path))
            else:
                # Checkpoint/GGUF default to Turbo (Base only available as Diffusers).
                variant = ZImageVariantType.Turbo

            config["variant"] = variant.value

            cursor.execute(
                "UPDATE models SET config = ? WHERE id = ?;",
                (json.dumps(config), model_id),
            )

            if variant is ZImageVariantType.ZBase:
                n_zbase += 1
            else:
                n_turbo += 1

        total = n_turbo + n_zbase
        if total > 0:
            self._logger.info(
                f"Migration complete: {total} Z-Image model configs updated "
                f"({n_turbo} Turbo, {n_zbase} Base)"
            )
        else:
            self._logger.info("Migration complete: no Z-Image model configs needed migration")
101+
102+
103+
def build_migration_26(app_config: InvokeAIAppConfig, logger: Logger) -> Migration:
    """Builds the migration object for migrating from version 25 to version 26.

    This migration adds the variant field to existing Z-Image main models.
    Models installed before the variant field was added will default to Turbo
    (the only variant available before Z-Image Base support was added).
    """
    callback = Migration26Callback(app_config=app_config, logger=logger)
    return Migration(from_version=25, to_version=26, callback=callback)

invokeai/backend/flux/schedulers.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,9 @@
4040
FLUX_SCHEDULER_MAP["lcm"] = FlowMatchLCMScheduler
4141

4242

43-
# Z-Image scheduler types (same schedulers as Flux, both use Flow Matching)
44-
# Note: Z-Image-Turbo is optimized for ~8 steps with Euler, but other schedulers
45-
# can be used for experimentation.
43+
# Z-Image scheduler types (Flow Matching schedulers)
44+
# Note: Z-Image-Turbo is optimized for ~8 steps with Euler, LCM can also work.
45+
# Z-Image Base (undistilled) should only use Euler or Heun (LCM not supported for undistilled models).
4646
ZIMAGE_SCHEDULER_NAME_VALUES = Literal["euler", "heun", "lcm"]
4747

4848
# Human-readable labels for the UI
@@ -52,7 +52,7 @@
5252
"lcm": "LCM",
5353
}
5454

55-
# Mapping from scheduler names to scheduler classes (same as Flux)
55+
# Mapping from scheduler names to scheduler classes
5656
ZIMAGE_SCHEDULER_MAP: dict[str, Type[SchedulerMixin]] = {
5757
"euler": FlowMatchEulerDiscreteScheduler,
5858
"heun": FlowMatchHeunDiscreteScheduler,

invokeai/backend/model_manager/configs/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ def get_tag(cls) -> Tag:
128128
# defaults. variant does not require a default, but if it has one, we need to add it to the tag. We can
129129
# check for the presence of a default by seeing if it's not PydanticUndefined, a sentinel value used by
130130
# pydantic to indicate that no default was provided.
131-
if field.default is not PydanticUndefined:
131+
if field.default is not PydanticUndefined and field.default is not None:
132132
# We expect each of these fields has an Enum for its default; we want the value of the enum.
133133
tag_strings.append(field.default.value)
134134
return Tag(".".join(tag_strings))

invokeai/backend/model_manager/configs/lora.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
FluxLoRAFormat,
2828
ModelFormat,
2929
ModelType,
30+
ZImageVariantType,
3031
)
3132
from invokeai.backend.model_manager.util.model_util import lora_token_vector_length
3233
from invokeai.backend.patches.lora_conversions.flux_control_lora_utils import is_state_dict_likely_flux_control
@@ -226,6 +227,7 @@ class LoRA_LyCORIS_ZImage_Config(LoRA_LyCORIS_Config_Base, Config_Base):
226227
"""Model config for Z-Image LoRA models in LyCORIS format."""
227228

228229
base: Literal[BaseModelType.ZImage] = Field(default=BaseModelType.ZImage)
230+
variant: ZImageVariantType | None = Field(default=None)
229231

230232
@classmethod
231233
def _validate_looks_like_lora(cls, mod: ModelOnDisk) -> None:
@@ -398,3 +400,4 @@ class LoRA_Diffusers_ZImage_Config(LoRA_Diffusers_Config_Base, Config_Base):
398400
"""Model config for Z-Image LoRA models in Diffusers format."""
399401

400402
base: Literal[BaseModelType.ZImage] = Field(default=BaseModelType.ZImage)
403+
variant: ZImageVariantType | None = Field(default=None)

invokeai/backend/model_manager/configs/main.py

Lines changed: 41 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@
3030
ModelVariantType,
3131
SchedulerPredictionType,
3232
SubModelType,
33+
ZImageVariantType,
3334
)
3435
from invokeai.backend.quantization.gguf.ggml_tensor import GGMLTensor
3536
from invokeai.backend.stable_diffusion.schedulers.schedulers import SCHEDULER_NAME_VALUES
@@ -57,7 +58,7 @@ class MainModelDefaultSettings(BaseModel):
5758
def from_base(
5859
cls,
5960
base: BaseModelType,
60-
variant: Flux2VariantType | FluxVariantType | ModelVariantType | None = None,
61+
variant: Flux2VariantType | FluxVariantType | ModelVariantType | ZImageVariantType | None = None,
6162
) -> Self | None:
6263
match base:
6364
case BaseModelType.StableDiffusion1:
@@ -67,7 +68,14 @@ def from_base(
6768
case BaseModelType.StableDiffusionXL:
6869
return cls(width=1024, height=1024)
6970
case BaseModelType.ZImage:
70-
return cls(steps=9, cfg_scale=1.0, width=1024, height=1024)
71+
# Different defaults based on variant
72+
if variant == ZImageVariantType.ZBase:
73+
# Undistilled base model needs more steps and supports CFG
74+
# Recommended: steps=28-50, cfg_scale=3.0-5.0
75+
return cls(steps=50, cfg_scale=4.0, width=1024, height=1024)
76+
else:
77+
# Turbo (distilled) uses fewer steps, no CFG
78+
return cls(steps=9, cfg_scale=1.0, width=1024, height=1024)
7179
case BaseModelType.Flux2:
7280
# Different defaults based on variant
7381
if variant == Flux2VariantType.Klein9BBase:
@@ -1077,9 +1085,10 @@ def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -
10771085

10781086

10791087
class Main_Diffusers_ZImage_Config(Diffusers_Config_Base, Main_Config_Base, Config_Base):
1080-
"""Model config for Z-Image diffusers models (Z-Image-Turbo, Z-Image-Base, Z-Image-Edit)."""
1088+
"""Model config for Z-Image diffusers models (Z-Image-Turbo, Z-Image-Base)."""
10811089

10821090
base: Literal[BaseModelType.ZImage] = Field(BaseModelType.ZImage)
1091+
variant: ZImageVariantType = Field()
10831092

10841093
@classmethod
10851094
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
@@ -1095,19 +1104,41 @@ def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -
10951104
},
10961105
)
10971106

1107+
variant = override_fields.get("variant") or cls._get_variant_or_raise(mod)
1108+
10981109
repo_variant = override_fields.get("repo_variant") or cls._get_repo_variant_or_raise(mod)
10991110

11001111
return cls(
11011112
**override_fields,
1113+
variant=variant,
11021114
repo_variant=repo_variant,
11031115
)
11041116

1117+
@classmethod
1118+
def _get_variant_or_raise(cls, mod: ModelOnDisk) -> ZImageVariantType:
1119+
"""Determine Z-Image variant from the scheduler config.
1120+
1121+
Z-Image variants are distinguished by the scheduler shift value:
1122+
- Turbo (distilled): shift = 3.0
1123+
- Base (undistilled): shift = 6.0
1124+
"""
1125+
scheduler_config = get_config_dict_or_raise(mod.path / "scheduler" / "scheduler_config.json")
1126+
1127+
shift = scheduler_config.get("shift", 3.0)
1128+
1129+
# ZBase (undistilled) uses shift = 6.0, Turbo uses shift = 3.0
1130+
if shift >= 5.0:
1131+
return ZImageVariantType.ZBase
1132+
else:
1133+
return ZImageVariantType.Turbo
1134+
11051135

11061136
class Main_Checkpoint_ZImage_Config(Checkpoint_Config_Base, Main_Config_Base, Config_Base):
11071137
"""Model config for Z-Image single-file checkpoint models (safetensors, etc)."""
11081138

11091139
base: Literal[BaseModelType.ZImage] = Field(default=BaseModelType.ZImage)
11101140
format: Literal[ModelFormat.Checkpoint] = Field(default=ModelFormat.Checkpoint)
1141+
variant: ZImageVariantType = Field()
11111142

11121143
@classmethod
11131144
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
@@ -1119,7 +1150,9 @@ def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -
11191150

11201151
cls._validate_does_not_look_like_gguf_quantized(mod)
11211152

1122-
return cls(**override_fields)
1153+
variant = override_fields.get("variant", ZImageVariantType.Turbo)
1154+
1155+
return cls(**override_fields, variant=variant)
11231156

11241157
@classmethod
11251158
def _validate_looks_like_z_image_model(cls, mod: ModelOnDisk) -> None:
@@ -1139,6 +1172,7 @@ class Main_GGUF_ZImage_Config(Checkpoint_Config_Base, Main_Config_Base, Config_B
11391172

11401173
base: Literal[BaseModelType.ZImage] = Field(default=BaseModelType.ZImage)
11411174
format: Literal[ModelFormat.GGUFQuantized] = Field(default=ModelFormat.GGUFQuantized)
1175+
variant: ZImageVariantType = Field()
11421176

11431177
@classmethod
11441178
def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -> Self:
@@ -1150,7 +1184,9 @@ def from_model_on_disk(cls, mod: ModelOnDisk, override_fields: dict[str, Any]) -
11501184

11511185
cls._validate_looks_like_gguf_quantized(mod)
11521186

1153-
return cls(**override_fields)
1187+
variant = override_fields.get("variant", ZImageVariantType.Turbo)
1188+
1189+
return cls(**override_fields, variant=variant)
11541190

11551191
@classmethod
11561192
def _validate_looks_like_z_image_model(cls, mod: ModelOnDisk) -> None:

0 commit comments

Comments
 (0)