@@ -137,3 +137,142 @@ def run():
137137
138138 # Ensure no queries are made using the default database.
139139 self .assertNumQueries (0 , run )
140+
141+ def test_create_from_super (self ):
142+ # run create test 3 times because initial implementation
143+ # would fail after first success.
144+ from polymorphic .tests .models import (
145+ NormalBase ,
146+ NormalExtension ,
147+ PolyExtension ,
148+ PolyExtChild ,
149+ )
150+
151+ nb = NormalBase .objects .db_manager ("secondary" ).create (nb_field = 1 )
152+ ne = NormalExtension .objects .db_manager ("secondary" ).create (nb_field = 2 , ne_field = "ne2" )
153+
154+ with self .assertRaises (TypeError ):
155+ PolyExtension .objects .db_manager ("secondary" ).create_from_super (nb , poly_ext_field = 3 )
156+
157+ pe = PolyExtension .objects .db_manager ("secondary" ).create_from_super (ne , poly_ext_field = 3 )
158+
159+ ne .refresh_from_db ()
160+ self .assertEqual (type (ne ), NormalExtension )
161+ self .assertEqual (type (pe ), PolyExtension )
162+ self .assertEqual (pe .pk , ne .pk )
163+
164+ self .assertEqual (pe .nb_field , 2 )
165+ self .assertEqual (pe .ne_field , "ne2" )
166+ self .assertEqual (pe .poly_ext_field , 3 )
167+ pe .refresh_from_db ()
168+ self .assertEqual (pe .nb_field , 2 )
169+ self .assertEqual (pe .ne_field , "ne2" )
170+ self .assertEqual (pe .poly_ext_field , 3 )
171+
172+ pc = PolyExtChild .objects .db_manager ("secondary" ).create_from_super (
173+ pe , poly_child_field = "pcf6"
174+ )
175+
176+ pe .refresh_from_db ()
177+ ne .refresh_from_db ()
178+ self .assertEqual (type (ne ), NormalExtension )
179+ self .assertEqual (type (pe ), PolyExtension )
180+ self .assertEqual (pe .pk , ne .pk )
181+ self .assertEqual (pe .pk , pc .pk )
182+
183+ self .assertEqual (pc .nb_field , 2 )
184+ self .assertEqual (pc .ne_field , "ne2" )
185+ self .assertEqual (pc .poly_ext_field , 3 )
186+ pc .refresh_from_db ()
187+ self .assertEqual (pc .nb_field , 2 )
188+ self .assertEqual (pc .ne_field , "ne2" )
189+ self .assertEqual (pc .poly_ext_field , 3 )
190+ self .assertEqual (pc .poly_child_field , "pcf6" )
191+
192+ self .assertEqual (
193+ pe .polymorphic_ctype ,
194+ ContentType .objects .db_manager ("secondary" ).get_for_model (PolyExtChild ),
195+ )
196+ self .assertEqual (
197+ pc .polymorphic_ctype ,
198+ ContentType .objects .db_manager ("secondary" ).get_for_model (PolyExtChild ),
199+ )
200+
201+ self .assertEqual (set (PolyExtension .objects .db_manager ("secondary" ).all ()), {pc })
202+
203+ a1 = Model2A .objects .db_manager ("secondary" ).create (field1 = "A1a" )
204+ a2 = Model2A .objects .db_manager ("secondary" ).create (field1 = "A1b" )
205+
206+ b1 = Model2B .objects .db_manager ("secondary" ).create (field1 = "B1a" , field2 = "B2a" )
207+ b2 = Model2B .objects .db_manager ("secondary" ).create (field1 = "B1b" , field2 = "B2b" )
208+
209+ c1 = Model2C .objects .db_manager ("secondary" ).create (
210+ field1 = "C1a" , field2 = "C2a" , field3 = "C3a"
211+ )
212+ c2 = Model2C .objects .db_manager ("secondary" ).create (
213+ field1 = "C1b" , field2 = "C2b" , field3 = "C3b"
214+ )
215+
216+ d1 = Model2D .objects .db_manager ("secondary" ).create (
217+ field1 = "D1a" , field2 = "D2a" , field3 = "D3a" , field4 = "D4a"
218+ )
219+ d2 = Model2D .objects .db_manager ("secondary" ).create (
220+ field1 = "D1b" , field2 = "D2b" , field3 = "D3b" , field4 = "D4b"
221+ )
222+
223+ with self .assertRaises (TypeError ):
224+ Model2D .objects .db_manager ("secondary" ).create_from_super (
225+ b1 , field3 = "D3x" , field4 = "D4x"
226+ )
227+
228+ b1_of_c = Model2B .objects .db_manager ("secondary" ).non_polymorphic ().get (pk = c1 .pk )
229+ with self .assertRaises (TypeError ):
230+ Model2C .objects .db_manager ("secondary" ).create_from_super (b1_of_c , field3 = "C3x" )
231+
232+ self .assertEqual (
233+ c1 .polymorphic_ctype ,
234+ ContentType .objects .db_manager ("secondary" ).get_for_model (Model2C ),
235+ )
236+ dfs1 = Model2D .objects .db_manager ("secondary" ).create_from_super (b1_of_c , field4 = "D4x" )
237+ self .assertEqual (type (dfs1 ), Model2D )
238+ self .assertEqual (dfs1 .pk , c1 .pk )
239+ self .assertEqual (dfs1 .field1 , "C1a" )
240+ self .assertEqual (dfs1 .field2 , "C2a" )
241+ self .assertEqual (dfs1 .field3 , "C3a" )
242+ self .assertEqual (dfs1 .field4 , "D4x" )
243+ self .assertEqual (
244+ dfs1 .polymorphic_ctype ,
245+ ContentType .objects .db_manager ("secondary" ).get_for_model (Model2D ),
246+ )
247+ c1 .refresh_from_db ()
248+ self .assertEqual (
249+ c1 .polymorphic_ctype ,
250+ ContentType .objects .db_manager ("secondary" ).get_for_model (Model2D ),
251+ )
252+
253+ self .assertEqual (
254+ b2 .polymorphic_ctype ,
255+ ContentType .objects .db_manager ("secondary" ).get_for_model (Model2B ),
256+ )
257+ cfs1 = Model2C .objects .db_manager ("secondary" ).create_from_super (b2 , field3 = "C3y" )
258+ self .assertEqual (type (cfs1 ), Model2C )
259+ self .assertEqual (cfs1 .pk , b2 .pk )
260+ self .assertEqual (cfs1 .field1 , "B1b" )
261+ self .assertEqual (cfs1 .field2 , "B2b" )
262+ self .assertEqual (cfs1 .field3 , "C3y" )
263+ b2 .refresh_from_db ()
264+ self .assertEqual (
265+ b2 .polymorphic_ctype ,
266+ ContentType .objects .db_manager ("secondary" ).get_for_model (Model2C ),
267+ )
268+ self .assertEqual (
269+ cfs1 .polymorphic_ctype ,
270+ ContentType .objects .db_manager ("secondary" ).get_for_model (Model2C ),
271+ )
272+
273+ self .assertEqual (
274+ set (Model2A .objects .db_manager ("secondary" ).all ()),
275+ {a1 , a2 , b1 , dfs1 , cfs1 , c2 , d1 , d2 },
276+ )
277+
278+ self .assertEqual (Model2A .objects .count (), 0 )
0 commit comments