Skip to content

Commit 1825f12

Browse files
committed
✨ support schema generate for configs in $files
1 parent 0c2f189 commit 1825f12

File tree

5 files changed

+482
-380
lines changed

5 files changed

+482
-380
lines changed

arclet/entari/config/file.py

Lines changed: 72 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,7 @@ class EntariConfig:
120120
plugin_extra_files: list[str] = field(default_factory=list, init=False)
121121
save_flag: bool = field(default=False)
122122
_origin_data: dict[str, Any] = field(init=False)
123-
_env_replaced: dict[int, str] = field(default_factory=dict, init=False)
123+
_env_replaced: dict[str, dict[int, str]] = field(default_factory=dict, init=False)
124124

125125
instance: ClassVar["EntariConfig"]
126126

@@ -136,7 +136,7 @@ def loader(self, path: Path):
136136
for i, line in enumerate(lines):
137137

138138
def handle(m: re.Match[str]):
139-
self._env_replaced[i] = line
139+
self._env_replaced.setdefault(path.as_posix(), {})[i] = line
140140
expr = m.group("expr")
141141
return safe_eval(expr, ctx)
142142

@@ -160,9 +160,9 @@ def dumper(self, path: Path, save_path: Path, data: dict, indent: int, apply_sch
160160
schema_file = f"{save_path.stem}.schema.json"
161161
if end in _dumpers:
162162
ans, applied = _dumpers[end](origin, indent, schema_file)
163-
if self._env_replaced:
163+
if path.as_posix() in self._env_replaced:
164164
lines = ans.splitlines(keepends=True)
165-
for i, line in self._env_replaced.items():
165+
for i, line in self._env_replaced[path.as_posix()].items():
166166
lines[i + applied] = line
167167
ans = "".join(lines)
168168
with save_path.open("w", encoding="utf-8") as f:
@@ -222,10 +222,10 @@ def reload(self):
222222
raise FileNotFoundError(file)
223223
if path.is_dir():
224224
for _path in path.iterdir():
225-
if not _path.is_file():
225+
if not _path.is_file() or _path.name.endswith(".schema.json"):
226226
continue
227227
self.plugin[_path.stem] = self.loader(_path)
228-
else:
228+
elif path.name.endswith(".schema.json"):
229229
self.plugin[path.stem] = self.loader(path)
230230
return True
231231

@@ -235,7 +235,7 @@ def _clean(value: T_M) -> T_M:
235235
value.pop("$static", None)
236236
return value
237237

238-
def dump(self, indent: int = 2):
238+
def dump(self, indent: int = 2, apply_schema: bool = False) -> dict[str, Any]:
239239
basic = self._origin_data.get("basic", {})
240240
if "log" not in basic and ("log_level" in basic or "log_ignores" in basic):
241241
basic["log"] = {}
@@ -247,12 +247,12 @@ def dump(self, indent: int = 2):
247247
if self.plugin_extra_files:
248248
for file in self.plugin_extra_files:
249249
path = Path(file)
250-
if path.is_file():
251-
self.dumper(path, path, self._clean(self.plugin.pop(path.stem)), indent, False)
250+
if path.is_file() and not path.name.endswith(".schema.json"):
251+
self.dumper(path, path, self._clean(self.plugin.pop(path.stem)), indent, apply_schema)
252252
else:
253253
for _path in path.iterdir():
254-
if _path.is_file():
255-
self.dumper(_path, _path, self._clean(self.plugin.pop(_path.stem)), indent, False)
254+
if _path.is_file() and not _path.name.endswith(".schema.json"):
255+
self.dumper(_path, _path, self._clean(self.plugin.pop(_path.stem)), indent, apply_schema)
256256
for key in list(self.plugin.keys()):
257257
if key.startswith("$"):
258258
continue
@@ -268,7 +268,7 @@ def dump(self, indent: int = 2):
268268

269269
def save(self, path: str | os.PathLike[str] | None = None, indent: int = 2, apply_schema: bool = False):
270270
self.save_flag = True
271-
self.dumper(self.path, Path(path or self.path), self.dump(indent), indent, apply_schema)
271+
self.dumper(self.path, Path(path or self.path), self.dump(indent, apply_schema), indent, apply_schema)
272272

273273
@classmethod
274274
def load(cls, path: str | os.PathLike[str] | None = None) -> "EntariConfig":
@@ -322,7 +322,10 @@ def generate_schema(self, plugins: list["Plugin"]):
322322
plugins_properties = {}
323323
# fmt: off
324324
plugin_meta_properties = {"$disable": {"type": "string", "description": "Expression for whether disable this plugin"}, "$prefix": {"type": "string", "description": "Plugin name prefix"}, "$priority": {"type": "integer", "description": "Plugin loading priority, lower value means higher priority (default: 16)"}, "$filter": {"type": "string", "description": "Plugin filter expression, which will be evaluated in the context of the plugin"}} # noqa: E501
325+
# Build a mapping from plugin config key to plugin object for $files schema generation
326+
plugin_map: dict[str, "Plugin"] = {} # noqa: UP037
325327
for plug in plugins:
328+
plugin_map[plug._config_key] = plug
326329
if plug.metadata is not None:
327330
if plug.metadata.config:
328331
schema = config_model_schema(plug.metadata.config, ref_root=f"/properties/plugins/properties/{plug._config_key}/") # noqa: E501
@@ -337,7 +340,64 @@ def generate_schema(self, plugins: list["Plugin"]):
337340
}
338341
with open(f"{self.path.stem}.schema.json", "w", encoding="utf-8") as f:
339342
json.dump({"$schema": "https://json-schema.org/draft/2020-12/schema", "type": "object", "properties": schemas, "additionalProperties": False, "required": ["basic"]}, f, indent=2, ensure_ascii=False) # noqa: E501
343+
344+
# Generate schema for each file in $files
345+
for file in self.plugin_extra_files:
346+
path = Path(file)
347+
if path.is_file() and not path.name.endswith(".schema.json"):
348+
self._generate_extra_file_schema(path, plugin_map, plugin_meta_properties)
349+
elif path.is_dir():
350+
for _path in path.iterdir():
351+
if _path.is_file() and not _path.name.endswith(".schema.json"):
352+
self._generate_extra_file_schema(_path, plugin_map, plugin_meta_properties)
340353
# fmt: on
354+
plugin_map.clear()
355+
356+
def _generate_extra_file_schema(self, path: Path, plugin_map: dict[str, "Plugin"], plugin_meta_properties: dict):
357+
"""Generate schema for an extra config file from $files."""
358+
plugin_key = path.stem
359+
schema_file = path.with_suffix(".schema.json")
360+
plugin_meta_properties = {
361+
**plugin_meta_properties,
362+
"$optional": {"type": "boolean", "description": "Whether this plugin is optional"},
363+
} # noqa: E501
364+
365+
# Check if we have a matching plugin with config
366+
if plugin_key in plugin_map:
367+
plug = plugin_map[plugin_key]
368+
if plug.metadata is not None and plug.metadata.config:
369+
plugin_schema = config_model_schema(plug.metadata.config, ref_root="/")
370+
plugin_schema["properties"].update(plugin_meta_properties)
371+
elif plug.metadata is not None:
372+
plugin_schema = {
373+
"type": "object",
374+
"description": f"{plug.metadata.description or plug.metadata.name}; no configuration required",
375+
"additionalProperties": True,
376+
"properties": plugin_meta_properties,
377+
} # noqa: E501
378+
else:
379+
plugin_schema = {
380+
"type": "object",
381+
"description": "No configuration required",
382+
"additionalProperties": True,
383+
"properties": plugin_meta_properties,
384+
} # noqa: E501
385+
else:
386+
# Plugin not found, generate a generic schema
387+
plugin_schema = {
388+
"type": "object",
389+
"description": f"Configuration for {plugin_key}",
390+
"additionalProperties": True,
391+
"properties": plugin_meta_properties,
392+
} # noqa: E501
393+
394+
with open(schema_file, "w", encoding="utf-8") as f:
395+
json.dump(
396+
{"$schema": "https://json-schema.org/draft/2020-12/schema", **plugin_schema},
397+
f,
398+
indent=2,
399+
ensure_ascii=False,
400+
) # noqa: E501
341401

342402

343403
load_config = EntariConfig.load

arclet/entari/core.py

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -111,80 +111,71 @@ async def __call__(self, context: Contexts):
111111
if ITEM_OPERATOR in context:
112112
return context[ITEM_OPERATOR]
113113
if ITEM_ORIGIN_EVENT not in context:
114-
return
115-
return context[ITEM_ORIGIN_EVENT].operator
114+
return context[ITEM_ORIGIN_EVENT].operator
116115

117116

118117
class UserProvider(Provider[User]):
119118
async def __call__(self, context: Contexts):
120119
if ITEM_USER in context:
121120
return context[ITEM_USER]
122121
if ITEM_ORIGIN_EVENT not in context:
123-
return
124-
return context[ITEM_ORIGIN_EVENT].user
122+
return context[ITEM_ORIGIN_EVENT].user
125123

126124

127125
class MessageProvider(Provider[MessageObject]):
128126
async def __call__(self, context: Contexts):
129127
if ITEM_MESSAGE_ORIGIN in context:
130128
return context[ITEM_MESSAGE_ORIGIN]
131129
if ITEM_ORIGIN_EVENT not in context:
132-
return
133-
return context[ITEM_ORIGIN_EVENT].message
130+
return context[ITEM_ORIGIN_EVENT].message
134131

135132

136133
class ChannelProvider(Provider[Channel]):
137134
async def __call__(self, context: Contexts):
138135
if ITEM_CHANNEL in context:
139136
return context[ITEM_CHANNEL]
140137
if ITEM_ORIGIN_EVENT not in context:
141-
return
142-
return context[ITEM_ORIGIN_EVENT].channel
138+
return context[ITEM_ORIGIN_EVENT].channel
143139

144140

145141
class GuildProvider(Provider[Guild]):
146142
async def __call__(self, context: Contexts):
147143
if ITEM_GUILD in context:
148144
return context[ITEM_GUILD]
149145
if ITEM_ORIGIN_EVENT not in context:
150-
return
151-
return context[ITEM_ORIGIN_EVENT].guild
146+
return context[ITEM_ORIGIN_EVENT].guild
152147

153148

154149
class MemberProvider(Provider[Member]):
155150
async def __call__(self, context: Contexts):
156151
if ITEM_MEMBER in context:
157152
return context[ITEM_MEMBER]
158153
if ITEM_ORIGIN_EVENT not in context:
159-
return
160-
return context[ITEM_ORIGIN_EVENT].member
154+
return context[ITEM_ORIGIN_EVENT].member
161155

162156

163157
class RoleProvider(Provider[Role]):
164158
async def __call__(self, context: Contexts):
165159
if ITEM_ROLE in context:
166160
return context[ITEM_ROLE]
167161
if ITEM_ORIGIN_EVENT not in context:
168-
return
169-
return context[ITEM_ORIGIN_EVENT].role
162+
return context[ITEM_ORIGIN_EVENT].role
170163

171164

172165
class EmojiProvider(Provider[EmojiObject]):
173166
async def __call__(self, context: Contexts):
174167
if ITEM_EMOJI_OBJECT in context:
175168
return context[ITEM_EMOJI_OBJECT]
176169
if ITEM_ORIGIN_EVENT not in context:
177-
return
178-
return context[ITEM_ORIGIN_EVENT].emoji
170+
return context[ITEM_ORIGIN_EVENT].emoji
179171

180172

181173
class LoginProvider(Provider[Login]):
182174
async def __call__(self, context: Contexts):
183175
if ITEM_LOGIN in context:
184176
return context[ITEM_LOGIN]
185177
if ITEM_ORIGIN_EVENT not in context:
186-
return
187-
return context[ITEM_ORIGIN_EVENT].login
178+
return context[ITEM_ORIGIN_EVENT].login
188179

189180

190181
class MessageContentProvider(Provider[MessageChain]):

arclet/entari/plugin/module.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,14 @@
22
from collections.abc import Sequence
33
from importlib import _bootstrap, _bootstrap_external # type: ignore
44
from importlib.abc import MetaPathFinder
5-
from importlib.machinery import ExtensionFileLoader, PathFinder, SourceFileLoader
5+
from importlib.machinery import ExtensionFileLoader, ModuleSpec, PathFinder, SourceFileLoader
66
from importlib.util import module_from_spec, resolve_name
77
from io import BytesIO
88
import re
99
import sys
1010
import tokenize
1111
from types import ModuleType
12+
from typing import Any
1213

1314
from arclet.letoderea import publish
1415
from arclet.letoderea.scope import scope_ctx
@@ -241,7 +242,7 @@ def create_module(self, spec) -> ModuleType | None:
241242
raise ReusablePluginError(f"reusable plugin {self.name!r} cannot be imported directly")
242243
return super().create_module(spec)
243244

244-
def exec_module(self, module: ModuleType, config: dict[str, str] | None = None) -> None:
245+
def exec_module(self, module: ModuleType, config: dict[str, Any] | None = None) -> None:
245246
is_sub = False
246247
if plugin := plugin_service.plugins.get(self.parent_plugin_id) if self.parent_plugin_id else None:
247248
plugin.subplugins.add(self.plugin_id)
@@ -335,7 +336,7 @@ def _path_find_spec(fullname, path=None, target=None):
335336
"""
336337
if path is None:
337338
path = sys.path
338-
spec = PathFinder._get_spec(fullname, path, target) # type: ignore
339+
spec: ModuleSpec | None = PathFinder._get_spec(fullname, path, target) # type: ignore
339340
if spec is None:
340341
return None
341342
elif spec.loader is None:
@@ -418,12 +419,13 @@ def find_spec(
418419
return
419420

420421

421-
def find_spec(id_, package=None):
422+
def find_spec(id_, package=None) -> ModuleSpec | None:
422423
uid_index = id_.rfind("@")
423424
name = id_ if uid_index == -1 else id_[:uid_index]
424425
fullname = resolve_name(name, package) if name.startswith(".") else name
425426
parent_name = fullname.rpartition(".")[0]
426427
if parent_name:
428+
parent: ModuleType | None
427429
parts = parent_name.split(".")
428430
_current = parts[0]
429431
if _current in plugin_service.plugins:
@@ -454,13 +456,12 @@ def find_spec(id_, package=None):
454456
enter_plugin = False
455457
parent = __import__(_current, fromlist=["__path__"])
456458
_current += "."
457-
try:
458-
parent_path = parent.__path__
459-
except AttributeError as e:
459+
if parent is None:
460460
raise ModuleNotFoundError(
461-
f"__path__ attribute not found on {parent_name!r} " f"while trying to find {fullname!r}",
461+
f"parent module {parent_name!r} does not have __path__ attribute " f"while trying to find {fullname!r}",
462462
name=fullname,
463-
) from e
463+
)
464+
parent_path = parent.__path__
464465
else:
465466
parent_path = None
466467
if isinstance(parent_path, _bootstrap_external._NamespacePath): # type: ignore

0 commit comments

Comments
 (0)