Skip to content

Commit e926e13

Browse files
authored
Merge pull request #684 from jazzband/create_from_super
Add create_from_super method and test @JCoxwell it only took 12 years but I'm finally merging this. Thank you!
2 parents 1d1d2d9 + 6cd80a3 commit e926e13

File tree

5 files changed

+301
-3
lines changed

5 files changed

+301
-3
lines changed

docs/advanced.rst

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -268,6 +268,23 @@ Similarly, pre-V1.0 output formatting can be re-estated by using
268268
``polymorphic_showfield_old_format = True``.
269269

270270

271+
Creating Subclass Objects from Existing Superclass Objects
272+
------------------------------------------------------------
273+
274+
You can create an instance of a subclass from an existing instance of a superclass using the
275+
:meth:`~polymorphic.managers.PolymorphicManager.create_from_super` method
276+
of the subclass's manager. For example:
277+
278+
.. code-block:: python
279+
280+
super_instance = ModelA.objects.get(id=1)
281+
sub_instance = ModelB.objects.create_from_super(super_instance, field2='value2')
282+
283+
The restriction is that ``super_instance`` must be an instance of the direct superclass of
284+
``ModelB``, and any required fields of ``ModelB`` must be provided as keyword arguments. If multiple
285+
levels of subclassing are involved, you must call this method multiple times to "promote" each
286+
level.
287+
271288
.. _restrictions:
272289

273290
Restrictions & Caveats

src/polymorphic/managers.py

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
The manager class for use in the models.
33
"""
44

5-
from django.db import models
5+
from django.contrib.contenttypes.models import ContentType
6+
from django.db import DEFAULT_DB_ALIAS, models
67

78
from polymorphic.query import PolymorphicQuerySet
89

@@ -49,3 +50,45 @@ def not_instance_of(self, *args):
4950

5051
def get_real_instances(self, base_result_objects=None):
5152
return self.all().get_real_instances(base_result_objects=base_result_objects)
53+
54+
def create_from_super(self, obj, **kwargs):
55+
"""
56+
Create an instance of this manager's model class from the given instance of a
57+
parent class.
58+
59+
This is useful when "promoting" an instance down the inheritance chain.
60+
61+
:param obj: An instance of a parent class of the manager's model class.
62+
:param kwargs: Additional fields to set on the new instance.
63+
:return: The newly created instance.
64+
"""
65+
from .models import PolymorphicModel
66+
67+
# ensure we have the most derived real instance
68+
if isinstance(obj, PolymorphicModel):
69+
obj = obj.get_real_instance()
70+
71+
parent_ptr = self.model._meta.parents.get(type(obj), None)
72+
73+
if not parent_ptr:
74+
raise TypeError(
75+
f"{obj.__class__.__name__} is not a direct parent of {self.model.__name__}"
76+
)
77+
kwargs[parent_ptr.get_attname()] = obj.pk
78+
79+
# create the new base class with only fields that apply to it.
80+
ctype = ContentType.objects.db_manager(
81+
using=(obj._state.db or DEFAULT_DB_ALIAS)
82+
).get_for_model(self.model)
83+
nobj = self.model(**kwargs, polymorphic_ctype=ctype)
84+
nobj.save_base(raw=True, using=obj._state.db or DEFAULT_DB_ALIAS, force_insert=True)
85+
# force update the content type, but first we need to
86+
# retrieve a clean copy from the db to fill in the null
87+
# fields otherwise they would be overwritten.
88+
if isinstance(obj, PolymorphicModel):
89+
parent = obj.__class__.objects.using(obj._state.db or DEFAULT_DB_ALIAS).get(pk=obj.pk)
90+
parent.polymorphic_ctype = ctype
91+
parent.save()
92+
93+
nobj.refresh_from_db() # cast to cls
94+
return nobj

src/polymorphic/tests/migrations/0001_initial.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Generated by Django 4.2 on 2025-12-13 10:57
1+
# Generated by Django 4.2 on 2025-12-13 22:56
22

33
from django.conf import settings
44
from django.db import migrations, models
@@ -13,8 +13,8 @@ class Migration(migrations.Migration):
1313
initial = True
1414

1515
dependencies = [
16-
('auth', '0012_alter_user_first_name_max_length'),
1716
('contenttypes', '0002_remove_content_type_name'),
17+
('auth', '0012_alter_user_first_name_max_length'),
1818
]
1919

2020
operations = [

src/polymorphic/tests/test_multidb.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -137,3 +137,142 @@ def run():
137137

138138
# Ensure no queries are made using the default database.
139139
self.assertNumQueries(0, run)
140+
141+
def test_create_from_super(self):
142+
# run create test 3 times because initial implementation
143+
# would fail after first success.
144+
from polymorphic.tests.models import (
145+
NormalBase,
146+
NormalExtension,
147+
PolyExtension,
148+
PolyExtChild,
149+
)
150+
151+
nb = NormalBase.objects.db_manager("secondary").create(nb_field=1)
152+
ne = NormalExtension.objects.db_manager("secondary").create(nb_field=2, ne_field="ne2")
153+
154+
with self.assertRaises(TypeError):
155+
PolyExtension.objects.db_manager("secondary").create_from_super(nb, poly_ext_field=3)
156+
157+
pe = PolyExtension.objects.db_manager("secondary").create_from_super(ne, poly_ext_field=3)
158+
159+
ne.refresh_from_db()
160+
self.assertEqual(type(ne), NormalExtension)
161+
self.assertEqual(type(pe), PolyExtension)
162+
self.assertEqual(pe.pk, ne.pk)
163+
164+
self.assertEqual(pe.nb_field, 2)
165+
self.assertEqual(pe.ne_field, "ne2")
166+
self.assertEqual(pe.poly_ext_field, 3)
167+
pe.refresh_from_db()
168+
self.assertEqual(pe.nb_field, 2)
169+
self.assertEqual(pe.ne_field, "ne2")
170+
self.assertEqual(pe.poly_ext_field, 3)
171+
172+
pc = PolyExtChild.objects.db_manager("secondary").create_from_super(
173+
pe, poly_child_field="pcf6"
174+
)
175+
176+
pe.refresh_from_db()
177+
ne.refresh_from_db()
178+
self.assertEqual(type(ne), NormalExtension)
179+
self.assertEqual(type(pe), PolyExtension)
180+
self.assertEqual(pe.pk, ne.pk)
181+
self.assertEqual(pe.pk, pc.pk)
182+
183+
self.assertEqual(pc.nb_field, 2)
184+
self.assertEqual(pc.ne_field, "ne2")
185+
self.assertEqual(pc.poly_ext_field, 3)
186+
pc.refresh_from_db()
187+
self.assertEqual(pc.nb_field, 2)
188+
self.assertEqual(pc.ne_field, "ne2")
189+
self.assertEqual(pc.poly_ext_field, 3)
190+
self.assertEqual(pc.poly_child_field, "pcf6")
191+
192+
self.assertEqual(
193+
pe.polymorphic_ctype,
194+
ContentType.objects.db_manager("secondary").get_for_model(PolyExtChild),
195+
)
196+
self.assertEqual(
197+
pc.polymorphic_ctype,
198+
ContentType.objects.db_manager("secondary").get_for_model(PolyExtChild),
199+
)
200+
201+
self.assertEqual(set(PolyExtension.objects.db_manager("secondary").all()), {pc})
202+
203+
a1 = Model2A.objects.db_manager("secondary").create(field1="A1a")
204+
a2 = Model2A.objects.db_manager("secondary").create(field1="A1b")
205+
206+
b1 = Model2B.objects.db_manager("secondary").create(field1="B1a", field2="B2a")
207+
b2 = Model2B.objects.db_manager("secondary").create(field1="B1b", field2="B2b")
208+
209+
c1 = Model2C.objects.db_manager("secondary").create(
210+
field1="C1a", field2="C2a", field3="C3a"
211+
)
212+
c2 = Model2C.objects.db_manager("secondary").create(
213+
field1="C1b", field2="C2b", field3="C3b"
214+
)
215+
216+
d1 = Model2D.objects.db_manager("secondary").create(
217+
field1="D1a", field2="D2a", field3="D3a", field4="D4a"
218+
)
219+
d2 = Model2D.objects.db_manager("secondary").create(
220+
field1="D1b", field2="D2b", field3="D3b", field4="D4b"
221+
)
222+
223+
with self.assertRaises(TypeError):
224+
Model2D.objects.db_manager("secondary").create_from_super(
225+
b1, field3="D3x", field4="D4x"
226+
)
227+
228+
b1_of_c = Model2B.objects.db_manager("secondary").non_polymorphic().get(pk=c1.pk)
229+
with self.assertRaises(TypeError):
230+
Model2C.objects.db_manager("secondary").create_from_super(b1_of_c, field3="C3x")
231+
232+
self.assertEqual(
233+
c1.polymorphic_ctype,
234+
ContentType.objects.db_manager("secondary").get_for_model(Model2C),
235+
)
236+
dfs1 = Model2D.objects.db_manager("secondary").create_from_super(b1_of_c, field4="D4x")
237+
self.assertEqual(type(dfs1), Model2D)
238+
self.assertEqual(dfs1.pk, c1.pk)
239+
self.assertEqual(dfs1.field1, "C1a")
240+
self.assertEqual(dfs1.field2, "C2a")
241+
self.assertEqual(dfs1.field3, "C3a")
242+
self.assertEqual(dfs1.field4, "D4x")
243+
self.assertEqual(
244+
dfs1.polymorphic_ctype,
245+
ContentType.objects.db_manager("secondary").get_for_model(Model2D),
246+
)
247+
c1.refresh_from_db()
248+
self.assertEqual(
249+
c1.polymorphic_ctype,
250+
ContentType.objects.db_manager("secondary").get_for_model(Model2D),
251+
)
252+
253+
self.assertEqual(
254+
b2.polymorphic_ctype,
255+
ContentType.objects.db_manager("secondary").get_for_model(Model2B),
256+
)
257+
cfs1 = Model2C.objects.db_manager("secondary").create_from_super(b2, field3="C3y")
258+
self.assertEqual(type(cfs1), Model2C)
259+
self.assertEqual(cfs1.pk, b2.pk)
260+
self.assertEqual(cfs1.field1, "B1b")
261+
self.assertEqual(cfs1.field2, "B2b")
262+
self.assertEqual(cfs1.field3, "C3y")
263+
b2.refresh_from_db()
264+
self.assertEqual(
265+
b2.polymorphic_ctype,
266+
ContentType.objects.db_manager("secondary").get_for_model(Model2C),
267+
)
268+
self.assertEqual(
269+
cfs1.polymorphic_ctype,
270+
ContentType.objects.db_manager("secondary").get_for_model(Model2C),
271+
)
272+
273+
self.assertEqual(
274+
set(Model2A.objects.db_manager("secondary").all()),
275+
{a1, a2, b1, dfs1, cfs1, c2, d1, d2},
276+
)
277+
278+
self.assertEqual(Model2A.objects.count(), 0)

src/polymorphic/tests/test_orm.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
CustomPkBase,
2525
CustomPkInherit,
2626
Enhance_Base,
27+
Enhance_Plain,
2728
Enhance_Inherit,
2829
InlineParent,
2930
InlineModelA,
@@ -1769,3 +1770,101 @@ def test_manytomany_with_through_field(self):
17691770
self.assertEqual(lake.ducks.count(), 2)
17701771
self.assertIsInstance(lake.ducks.all()[0], RubberDuck)
17711772
self.assertIsInstance(lake.ducks.all()[1], RedheadDuck)
1773+
1774+
def test_create_from_super(self):
1775+
# run create test 3 times because initial implementation
1776+
# would fail after first success.
1777+
from polymorphic.tests.models import (
1778+
NormalBase,
1779+
NormalExtension,
1780+
PolyExtension,
1781+
PolyExtChild,
1782+
)
1783+
1784+
nb = NormalBase.objects.create(nb_field=1)
1785+
ne = NormalExtension.objects.create(nb_field=2, ne_field="ne2")
1786+
1787+
with self.assertRaises(TypeError):
1788+
PolyExtension.objects.create_from_super(nb, poly_ext_field=3)
1789+
1790+
pe = PolyExtension.objects.create_from_super(ne, poly_ext_field=3)
1791+
1792+
ne.refresh_from_db()
1793+
self.assertEqual(type(ne), NormalExtension)
1794+
self.assertEqual(type(pe), PolyExtension)
1795+
self.assertEqual(pe.pk, ne.pk)
1796+
1797+
self.assertEqual(pe.nb_field, 2)
1798+
self.assertEqual(pe.ne_field, "ne2")
1799+
self.assertEqual(pe.poly_ext_field, 3)
1800+
pe.refresh_from_db()
1801+
self.assertEqual(pe.nb_field, 2)
1802+
self.assertEqual(pe.ne_field, "ne2")
1803+
self.assertEqual(pe.poly_ext_field, 3)
1804+
1805+
pc = PolyExtChild.objects.create_from_super(pe, poly_child_field="pcf6")
1806+
1807+
pe.refresh_from_db()
1808+
ne.refresh_from_db()
1809+
self.assertEqual(type(ne), NormalExtension)
1810+
self.assertEqual(type(pe), PolyExtension)
1811+
self.assertEqual(pe.pk, ne.pk)
1812+
self.assertEqual(pe.pk, pc.pk)
1813+
1814+
self.assertEqual(pc.nb_field, 2)
1815+
self.assertEqual(pc.ne_field, "ne2")
1816+
self.assertEqual(pc.poly_ext_field, 3)
1817+
pc.refresh_from_db()
1818+
self.assertEqual(pc.nb_field, 2)
1819+
self.assertEqual(pc.ne_field, "ne2")
1820+
self.assertEqual(pc.poly_ext_field, 3)
1821+
self.assertEqual(pc.poly_child_field, "pcf6")
1822+
1823+
self.assertEqual(pe.polymorphic_ctype, ContentType.objects.get_for_model(PolyExtChild))
1824+
self.assertEqual(pc.polymorphic_ctype, ContentType.objects.get_for_model(PolyExtChild))
1825+
1826+
self.assertEqual(set(PolyExtension.objects.all()), {pc})
1827+
1828+
a1 = Model2A.objects.create(field1="A1a")
1829+
a2 = Model2A.objects.create(field1="A1b")
1830+
1831+
b1 = Model2B.objects.create(field1="B1a", field2="B2a")
1832+
b2 = Model2B.objects.create(field1="B1b", field2="B2b")
1833+
1834+
c1 = Model2C.objects.create(field1="C1a", field2="C2a", field3="C3a")
1835+
c2 = Model2C.objects.create(field1="C1b", field2="C2b", field3="C3b")
1836+
1837+
d1 = Model2D.objects.create(field1="D1a", field2="D2a", field3="D3a", field4="D4a")
1838+
d2 = Model2D.objects.create(field1="D1b", field2="D2b", field3="D3b", field4="D4b")
1839+
1840+
with self.assertRaises(TypeError):
1841+
Model2D.objects.create_from_super(b1, field3="D3x", field4="D4x")
1842+
1843+
b1_of_c = Model2B.objects.non_polymorphic().get(pk=c1.pk)
1844+
with self.assertRaises(TypeError):
1845+
Model2C.objects.create_from_super(b1_of_c, field3="C3x")
1846+
1847+
self.assertEqual(c1.polymorphic_ctype, ContentType.objects.get_for_model(Model2C))
1848+
dfs1 = Model2D.objects.create_from_super(b1_of_c, field4="D4x")
1849+
self.assertEqual(type(dfs1), Model2D)
1850+
self.assertEqual(dfs1.pk, c1.pk)
1851+
self.assertEqual(dfs1.field1, "C1a")
1852+
self.assertEqual(dfs1.field2, "C2a")
1853+
self.assertEqual(dfs1.field3, "C3a")
1854+
self.assertEqual(dfs1.field4, "D4x")
1855+
self.assertEqual(dfs1.polymorphic_ctype, ContentType.objects.get_for_model(Model2D))
1856+
c1.refresh_from_db()
1857+
self.assertEqual(c1.polymorphic_ctype, ContentType.objects.get_for_model(Model2D))
1858+
1859+
self.assertEqual(b2.polymorphic_ctype, ContentType.objects.get_for_model(Model2B))
1860+
cfs1 = Model2C.objects.create_from_super(b2, field3="C3y")
1861+
self.assertEqual(type(cfs1), Model2C)
1862+
self.assertEqual(cfs1.pk, b2.pk)
1863+
self.assertEqual(cfs1.field1, "B1b")
1864+
self.assertEqual(cfs1.field2, "B2b")
1865+
self.assertEqual(cfs1.field3, "C3y")
1866+
b2.refresh_from_db()
1867+
self.assertEqual(b2.polymorphic_ctype, ContentType.objects.get_for_model(Model2C))
1868+
self.assertEqual(cfs1.polymorphic_ctype, ContentType.objects.get_for_model(Model2C))
1869+
1870+
self.assertEqual(set(Model2A.objects.all()), {a1, a2, b1, dfs1, cfs1, c2, d1, d2})

0 commit comments

Comments
 (0)