Skip to content

Commit 65d1b08

Browse files
Merge remote-tracking branch 'origin/master' into pre/v2.1
2 parents 243c9e0 + c738cae commit 65d1b08

File tree

3 files changed

+87
-12
lines changed

3 files changed

+87
-12
lines changed

.pre-commit-config.yaml

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,3 @@ repos:
4747
- types-tqdm
4848
- pandas-stubs
4949
- numpy
50-
- repo: local
51-
hooks:
52-
- id: unit-tests
53-
name: unit tests
54-
entry: pytest tests/unit/ -v --tb=short
55-
language: system
56-
pass_filenames: false
57-
always_run: true
58-
stages: [pre-commit]

src/datajoint/autopopulate.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -201,7 +201,7 @@ def _rename_attributes(table, props):
201201
self._key_source *= _rename_attributes(*q)
202202
return self._key_source
203203

204-
def make(self, key: dict[str, Any]) -> None | Generator[Any, Any, None]:
204+
def make(self, key: dict[str, Any], **kwargs) -> None | Generator[Any, Any, None]:
205205
"""
206206
Compute and insert data for one key.
207207
@@ -216,6 +216,9 @@ def make(self, key: dict[str, Any]) -> None | Generator[Any, Any, None]:
216216
----------
217217
key : dict
218218
Primary key value identifying the entity to compute.
219+
**kwargs
220+
Keyword arguments passed from ``populate(make_kwargs=...)``.
221+
These are forwarded to ``make_fetch`` for the tripartite pattern.
219222
220223
Raises
221224
------
@@ -229,7 +232,7 @@ def make(self, key: dict[str, Any]) -> None | Generator[Any, Any, None]:
229232
230233
**Tripartite make**: For long-running computations, implement:
231234
232-
- ``make_fetch(key)``: Fetch data from parent tables
235+
- ``make_fetch(key, **kwargs)``: Fetch data from parent tables
233236
- ``make_compute(key, *fetched_data)``: Compute results
234237
- ``make_insert(key, *computed_result)``: Insert results
235238
@@ -247,7 +250,7 @@ def make(self, key: dict[str, Any]) -> None | Generator[Any, Any, None]:
247250
# User has implemented `_fetch`, `_compute`, and `_insert` methods instead
248251

249252
# Step 1: Fetch data from parent tables
250-
fetched_data = self.make_fetch(key) # fetched_data is a tuple
253+
fetched_data = self.make_fetch(key, **kwargs) # fetched_data is a tuple
251254
computed_result = yield fetched_data # passed as input into make_compute
252255

253256
# Step 2: If computed result is not passed in, compute the result

tests/integration/test_autopopulate.py

Lines changed: 81 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,3 +147,84 @@ def make(self, key):
147147
self.insert1(dict(key, crop_image=dict()))
148148

149149
Crop.populate()
150+
151+
152+
def test_make_kwargs_regular(prefix, connection_test):
153+
"""Test that make_kwargs are passed to regular make method."""
154+
schema = dj.Schema(f"{prefix}_make_kwargs_regular", connection=connection_test)
155+
156+
@schema
157+
class Source(dj.Lookup):
158+
definition = """
159+
source_id: int
160+
"""
161+
contents = [(1,), (2,)]
162+
163+
@schema
164+
class Computed(dj.Computed):
165+
definition = """
166+
-> Source
167+
---
168+
multiplier: int
169+
result: int
170+
"""
171+
172+
def make(self, key, multiplier=1):
173+
self.insert1(dict(key, multiplier=multiplier, result=key["source_id"] * multiplier))
174+
175+
# Test without make_kwargs
176+
Computed.populate(Source & "source_id = 1")
177+
assert (Computed & "source_id = 1").fetch1("result") == 1
178+
179+
# Test with make_kwargs
180+
Computed.populate(Source & "source_id = 2", make_kwargs={"multiplier": 10})
181+
assert (Computed & "source_id = 2").fetch1("multiplier") == 10
182+
assert (Computed & "source_id = 2").fetch1("result") == 20
183+
184+
185+
def test_make_kwargs_tripartite(prefix, connection_test):
186+
"""Test that make_kwargs are passed to make_fetch in tripartite pattern (issue #1350)."""
187+
schema = dj.Schema(f"{prefix}_make_kwargs_tripartite", connection=connection_test)
188+
189+
@schema
190+
class Source(dj.Lookup):
191+
definition = """
192+
source_id: int
193+
---
194+
value: int
195+
"""
196+
contents = [(1, 100), (2, 200)]
197+
198+
@schema
199+
class TripartiteComputed(dj.Computed):
200+
definition = """
201+
-> Source
202+
---
203+
scale: int
204+
result: int
205+
"""
206+
207+
def make_fetch(self, key, scale=1):
208+
"""Fetch data with optional scale parameter."""
209+
value = (Source & key).fetch1("value")
210+
return (value, scale)
211+
212+
def make_compute(self, key, value, scale):
213+
"""Compute result using fetched value and scale."""
214+
return (value * scale, scale)
215+
216+
def make_insert(self, key, result, scale):
217+
"""Insert computed result."""
218+
self.insert1(dict(key, scale=scale, result=result))
219+
220+
# Test without make_kwargs (scale defaults to 1)
221+
TripartiteComputed.populate(Source & "source_id = 1")
222+
row = (TripartiteComputed & "source_id = 1").fetch1()
223+
assert row["scale"] == 1
224+
assert row["result"] == 100 # 100 * 1
225+
226+
# Test with make_kwargs (scale = 5)
227+
TripartiteComputed.populate(Source & "source_id = 2", make_kwargs={"scale": 5})
228+
row = (TripartiteComputed & "source_id = 2").fetch1()
229+
assert row["scale"] == 5
230+
assert row["result"] == 1000 # 200 * 5

0 commit comments

Comments
 (0)