Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions modelsearch/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,19 +298,20 @@ def check(self):
"""
Checks that the search query satisfies the following conditions:

1. All field names passed in the ``fields`` parameter exist as ``SearchField`` records on the model.
1. All field names passed in the ``fields`` parameter exist as searchable fields on the model.
2. All fields used within filters on the passed ``queryset`` exist as ``FilterField`` records on the model.
3. The ``order_by`` clause on the passed ``queryset`` does not contain any expressions other than plain field
names and their reversed counterparts (``"some_field"`` and ``"-some_field"``), unless
``HANDLES_ORDER_BY_EXPRESSIONS`` is ``True``.
names and their reversed counterparts (``"some_field"`` and ``"-some_field"``), unless
``HANDLES_ORDER_BY_EXPRESSIONS`` is ``True``.
4. All field names within the ``order_by`` clause on the passed ``queryset`` exist as ``FilterField`` records
on the model.
on the model.
"""
# Check search fields
if self.fields:
# allowed_fields now uses full lookup names
allowed_fields = {
field.field_name
for field in self.queryset.model.get_searchable_search_fields()
full_name
for _, full_name in self.queryset.model.get_searchable_search_fields_with_relatives()
}

for field_name in self.fields:
Expand Down
3 changes: 2 additions & 1 deletion modelsearch/backends/database/fallback.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def __init__(self, *args, **kwargs):
def get_fields_names(self):
model = self.queryset.model
fields_names = self.fields or [
field.field_name for field in model.get_searchable_search_fields()
field.field_name
for field, _full_name in model.get_searchable_search_fields()
]
# Check if the field exists (this will filter out indexed callables)
for field_name in fields_names:
Expand Down
67 changes: 30 additions & 37 deletions modelsearch/backends/database/mysql/mysql.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,9 @@

from collections import OrderedDict

from django.db import (
NotSupportedError,
connections,
router,
transaction,
)
from django.db import NotSupportedError, connections, router, transaction
from django.db.models import Case, OuterRef, Subquery, When
from django.db.models.aggregates import Avg, Count
from django.db.models.constants import LOOKUP_SEP
from django.db.models.expressions import F
from django.db.models.fields import BooleanField, FloatField, TextField
from django.db.models.functions.comparison import Cast
Expand Down Expand Up @@ -292,7 +286,9 @@ def __init__(self, *args, **kwargs):

if self.fields is None:
# search over the fields defined on the current model
self.search_fields = local_search_fields
self.search_fields = {
full_name: field for field, full_name in local_search_fields
}
else:
# build a search_fields set from the passed definition,
# which may involve traversing relations
Expand All @@ -307,37 +303,34 @@ def get_config(self, backend):
return backend.config

def get_search_fields_for_model(self):
return self.queryset.model.get_searchable_search_fields()
return self.queryset.model.get_searchable_search_fields_with_relatives()

def get_search_field(self, field_lookup, fields=None):
if fields is None:
fields = self.search_fields
def get_search_field(self, full_name, fields=None, as_tuple=False):
"""
Returns the SearchField (or AutocompleteField) for the given full_name.

if LOOKUP_SEP in field_lookup:
field_lookup, sub_field_name = field_lookup.split(LOOKUP_SEP, 1)
else:
sub_field_name = None

for field in fields:
if (
isinstance(field, self.TARGET_SEARCH_FIELD_TYPE)
and field.field_name == field_lookup
):
return field

# Note: Searching on a specific related field using
# `.search(fields=…)` is not yet supported by Wagtail.
# This method anticipates by already implementing it.
# FIXME: this doesn't work because the list we're looping over comes from
# get_search_fields_for_model, which only returns `SearchField` records, not `RelatedFields`
if (
isinstance(field, RelatedFields)
and field.field_name == field_lookup
and sub_field_name is not None
):
return self.get_search_field(
sub_field_name, field.fields
) # pragma: no cover
:param full_name: the flattened field lookup, e.g., "authors__name"
:param fields: list of tuples (field_obj, full_name) from get_search_fields_for_model
:param as_tuple: if True, return (field_obj, full_lookup_name) tuple
"""
if fields is None:
fields = self.search_fields # could be dict {full_name: field_obj}

# Dict case
if isinstance(fields, dict):
field = fields.get(full_name)
if field is not None and as_tuple:
return field, full_name
return field

# List of tuples case
for field_obj, fname in fields:
if fname == full_name:
if as_tuple:
return field_obj, fname
return field_obj

return None

def build_search_query_content(self, query, invert=False):
if isinstance(query, PlainText):
Expand Down
104 changes: 63 additions & 41 deletions modelsearch/backends/database/postgres/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections import OrderedDict
from functools import reduce

from django.contrib.postgres.aggregates import StringAgg
from django.contrib.postgres.search import (
SearchQuery,
SearchRank,
Expand All @@ -30,7 +31,6 @@
Value,
When,
)
from django.db.models.constants import LOOKUP_SEP
from django.db.models.functions import Cast, Greatest, Length
from django.db.models.sql.subqueries import InsertQuery
from django.utils.encoding import force_str
Expand Down Expand Up @@ -473,7 +473,9 @@ def __init__(self, *args, **kwargs):

if self.fields is None:
# search over the fields defined on the current model
self.search_fields = local_search_fields
self.search_fields = {
full_name: field for field, full_name in local_search_fields
}
else:
# build a search_fields set from the passed definition,
# which may involve traversing relations
Expand All @@ -488,37 +490,34 @@ def get_config(self, backend):
return backend.config

def get_search_fields_for_model(self):
return self.queryset.model.get_searchable_search_fields()
return self.queryset.model.get_searchable_search_fields_with_relatives()

def get_search_field(self, full_name, fields=None, as_tuple=False):
"""
Returns the SearchField (or AutocompleteField) for the given full_name.

def get_search_field(self, field_lookup, fields=None):
:param full_name: the flattened field lookup, e.g., "authors__name"
:param fields: list of tuples (field_obj, full_name) from get_search_fields_for_model
:param as_tuple: if True, return (field_obj, full_lookup_name) tuple
"""
if fields is None:
fields = self.search_fields
fields = self.search_fields # could be dict {full_name: field_obj}

if LOOKUP_SEP in field_lookup:
field_lookup, sub_field_name = field_lookup.split(LOOKUP_SEP, 1)
else:
sub_field_name = None

for field in fields:
if (
isinstance(field, self.TARGET_SEARCH_FIELD_TYPE)
and field.field_name == field_lookup
):
return field

# Note: Searching on a specific related field using
# `.search(fields=…)` is not yet supported by Wagtail.
# This method anticipates by already implementing it.
# FIXME: this doesn't work because the list we're looping over comes from
# get_search_fields_for_model, which only returns `SearchField` records, not `RelatedFields`
if (
isinstance(field, RelatedFields)
and field.field_name == field_lookup
and sub_field_name is not None
):
return self.get_search_field(
sub_field_name, field.fields
) # pragma: no cover
# Dict case
if isinstance(fields, dict):
field = fields.get(full_name)
if field is not None and as_tuple:
return field, full_name
return field

# List of tuples case
for field_obj, fname in fields:
if fname == full_name:
if as_tuple:
return field_obj, fname
return field_obj

return None

def build_tsquery_content(self, query, config=None, invert=False):
if isinstance(query, PlainText):
Expand Down Expand Up @@ -595,7 +594,7 @@ def build_tsquery(self, query, config=None):
return self.build_tsquery_content(query, config=config)

def build_tsrank(self, vector, query, config=None, boost=1.0):
if isinstance(query, (Phrase, PlainText, Not)):
if isinstance(query, Phrase | PlainText | Not):
rank_expression = SearchRank(
vector,
self.build_tsquery(query, config=config),
Expand Down Expand Up @@ -830,17 +829,40 @@ def get_index_vectors(self, search_query):
return [(F("index_entries__autocomplete"), 1.0)]

def get_fields_vectors(self, search_query):
return [
(
SearchVector(
field_lookup,
config=search_query.config,
weight="D",
),
1.0,
vectors = []

for field_lookup, _search_field in self.search_fields.items():
# Get (field_obj, full_lookup_name) tuple safely
result = self.get_search_field(full_name=field_lookup, as_tuple=True)
if result is None:
continue # skip missing fields

root_field_obj, full_name = result

# For autocomplete fields, there is usually no subfield
# We'll treat full_name as the vector source
vector_field = full_name

# Optionally annotate if you want to handle subfields (rare for autocomplete)
annotated_name = f"{root_field_obj.field_name}_{full_name}"
if annotated_name not in self.queryset.query.annotations:
self.queryset = self.queryset.annotate(
**{annotated_name: StringAgg(full_name, delimiter=" ")}
)
vector_field = annotated_name

vectors.append(
(
SearchVector(
vector_field,
config=search_query.config,
weight="D",
),
getattr(root_field_obj, "boost", 1.0),
)
)
for field_lookup, search_field in self.search_fields.items()
]

return vectors


class PostgresSearchResults(BaseSearchResults):
Expand Down
59 changes: 23 additions & 36 deletions modelsearch/backends/database/sqlite/sqlite.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
transaction,
)
from django.db.models import Avg, Case, Count, F, Manager, TextField, When
from django.db.models.constants import LOOKUP_SEP
from django.db.models.functions import Cast, Length
from django.utils.encoding import force_str
from django.utils.functional import cached_property
Expand All @@ -18,12 +17,7 @@

from ....index import AutocompleteField, RelatedFields, SearchField, get_indexed_models
from ....query import And, MatchAll, Not, Or, Phrase, PlainText
from ....utils import (
ADD,
MUL,
get_content_type_pk,
get_descendants_content_types_pks,
)
from ....utils import ADD, MUL, get_content_type_pk, get_descendants_content_types_pks
from ...base import (
BaseIndex,
BaseSearchBackend,
Expand Down Expand Up @@ -324,7 +318,9 @@ def __init__(self, *args, **kwargs):

if self.fields is None:
# search over the fields defined on the current model
self.search_fields = local_search_fields
self.search_fields = {
full_name: field for field, full_name in local_search_fields
}
else:
# build a search_fields set from the passed definition,
# which may involve traversing relations
Expand All @@ -339,37 +335,28 @@ def get_config(self, backend):
return backend.config

def get_search_fields_for_model(self):
return self.queryset.model.get_searchable_search_fields()
return self.queryset.model.get_searchable_search_fields_with_relatives()

def get_search_field(self, full_name, fields=None):
"""
Returns the SearchField object for the given full_name.

def get_search_field(self, field_lookup, fields=None):
:param full_name: the flattened field lookup, e.g., "authors__name"
:param fields: list of tuples (SearchField, full_name) from get_search_fields_for_model
"""
if fields is None:
fields = self.search_fields
fields = self.search_fields # fallback to the dict {full_name: SearchField}

if LOOKUP_SEP in field_lookup:
field_lookup, sub_field_name = field_lookup.split(LOOKUP_SEP, 1)
else:
sub_field_name = None

for field in fields:
if (
isinstance(field, self.TARGET_SEARCH_FIELD_TYPE)
and field.field_name == field_lookup
):
return field

# Note: Searching on a specific related field using
# `.search(fields=…)` is not yet supported by Wagtail.
# This method anticipates by already implementing it.
# FIXME: this doesn't work because the list we're looping over comes from
# get_search_fields_for_model, which only returns `SearchField` records, not `RelatedFields`
if (
isinstance(field, RelatedFields)
and field.field_name == field_lookup
and sub_field_name is not None
):
return self.get_search_field(
sub_field_name, field.fields
) # pragma: no cover
# If it's a dict, just lookup directly
if isinstance(fields, dict):
return fields.get(full_name)

# Otherwise, iterate through the tuples
for field_obj, fname in fields:
if fname == full_name:
return field_obj

return None

def build_search_query_content(self, query, config=None):
"""
Expand Down
Loading
Loading