Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions src/tap/tapify.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@

import dataclasses
import inspect
from typing import Any, Callable, Optional, Sequence, TypeVar
from types import SimpleNamespace

# TODO: 3.11 use only Annotated to combine pydantic metadata
from typing import Any, Callable, Optional, Sequence, TypeVar, _AnnotatedAlias, get_type_hints

from docstring_parser import Docstring, parse

Expand All @@ -22,8 +25,8 @@
else:
_IS_PYDANTIC_V1 = pydantic.VERSION.startswith("1.")
from pydantic import BaseModel
from pydantic.fields import FieldInfo as PydanticFieldBaseModel
from pydantic.dataclasses import FieldInfo as PydanticFieldDataclass
from pydantic.fields import FieldInfo as PydanticFieldBaseModel

_PydanticField = PydanticFieldBaseModel | PydanticFieldDataclass
# typing.get_args(_PydanticField) is an empty tuple for some reason. Just repeat
Expand Down Expand Up @@ -59,6 +62,9 @@ class _ArgData:
is_positional_only: bool = False
"Whether or not the argument must be provided positionally"

pydantic_metadata: Optional[tuple[Any]] = None
"Additional metadata from Annotated fields in Pydantic models"""


@dataclasses.dataclass(frozen=True)
class _TapData:
Expand Down Expand Up @@ -127,7 +133,7 @@ def arg_data_from_pydantic(name: str, field: _PydanticField, annotation: Optiona
# Prefer the description from param_to_description (from the data model / class docstring) over the
# field.description b/c a docstring can be modified on the fly w/o causing real issues
description = param_to_description.get(name, field.description)
return _ArgData(name, annotation, field.is_required(), field.default, description)
return _ArgData(name, annotation, field.is_required(), field.default, description, pydantic_metadata=tuple(field.metadata))

# Determine what type of data model it is and extract fields accordingly
if dataclasses.is_dataclass(data_model):
Expand Down Expand Up @@ -269,6 +275,9 @@ def _tap_data(class_or_function: _ClassOrFunction, param_to_description: dict[st
# TODO: allow passing func_kwargs to a Pydantic BaseModel
return _tap_data_from_class_or_function(class_or_function, func_kwargs, param_to_description)

def _remove_extras_from_annotation(annotation):
"""Removes extras from annotation types, e.g., Annotated, etc."""
return get_type_hints(SimpleNamespace(__annotations__ = {"dummy": annotation}))["dummy"]

def _tap_class(args_data: Sequence[_ArgData]) -> type[Tap]:
"""
Expand All @@ -282,7 +291,19 @@ def _configure(self):
for arg_data in args_data:
variable = arg_data.name
if variable not in self.class_variables:
self._annotations[variable] = str if arg_data.annotation is Any else arg_data.annotation
annotation = str if arg_data.annotation is Any else arg_data.annotation
if arg_data.pydantic_metadata:

# Pydantic does clean Annotated metadata, so we need to add it here
# Make sure we have an _AnnotatedAlias and add the fields metadata to it
# TODO: 3.11 use star expression:
# annotation = Annotated[annotation, *arg_data.metadata]
if not isinstance(annotation, _AnnotatedAlias):
annotation = _AnnotatedAlias(annotation, arg_data.pydantic_metadata)
else:
annotation.__metadata__ = (*annotation.__metadata__, *arg_data.pydantic_metadata)
self._annotations_with_extras[variable] = annotation
self._annotations[variable] = _remove_extras_from_annotation(annotation)
self.class_variables[variable] = {"comment": arg_data.description or ""}
if arg_data.is_required:
kwargs = {}
Expand Down
11 changes: 10 additions & 1 deletion tests/test_to_tap_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import io
import re
import sys
from typing import Any, Callable, List, Literal, Optional, Type, Union
from typing import Annotated, Any, Callable, List, Literal, Optional, Type, Union

import pytest

Expand Down Expand Up @@ -688,3 +688,12 @@ class TapGrandchild(to_tap_class(TapChild)):
args = ["--d", "4"]
with pytest.raises(SystemExit):
TapGrandchild().parse_args(args)

def test_extras_removal():
class Parent:
def __init__(self, an_int: Annotated[int, "metadata"] = 1):
pass

tapped = to_tap_class(Parent)
assert tapped()._annotations["an_int"] == int
assert tapped()._annotations_with_extras["an_int"] == Annotated[int, "metadata"]