@@ -948,6 +948,43 @@ def test_union_of_unions_of_models_with_tagged_union_invalid_variant(
948948 assert m in str (w [0 ].message )
949949
950950
951+ def test_mixed_union_models_and_other_types () -> None :
952+ s = SchemaSerializer (
953+ core_schema .union_schema (
954+ [
955+ core_schema .tagged_union_schema (
956+ discriminator = 'type_' ,
957+ choices = {
958+ 'cat' : core_schema .model_schema (
959+ cls = ModelCat ,
960+ schema = core_schema .model_fields_schema (
961+ fields = {
962+ 'type_' : core_schema .model_field (core_schema .literal_schema (['cat' ])),
963+ },
964+ ),
965+ ),
966+ 'dog' : core_schema .model_schema (
967+ cls = ModelDog ,
968+ schema = core_schema .model_fields_schema (
969+ fields = {
970+ 'type_' : core_schema .model_field (core_schema .literal_schema (['dog' ])),
971+ },
972+ ),
973+ ),
974+ },
975+ ),
976+ core_schema .str_schema (),
977+ ]
978+ )
979+ )
980+
981+ assert s .to_python (ModelCat (type_ = 'cat' ), warnings = 'error' ) == {'type_' : 'cat' }
982+ assert s .to_python (ModelDog (type_ = 'dog' ), warnings = 'error' ) == {'type_' : 'dog' }
983+ # note, this fails as ModelCat and ModelDog (discriminator warnings, etc), but the warnings
984+ # don't bubble up to this level :)
985+ assert s .to_python ('a string' , warnings = 'error' ) == 'a string'
986+
987+
951988@pytest .mark .parametrize (
952989 'input,expected' ,
953990 [
@@ -1000,3 +1037,28 @@ def test_union_of_unions_of_models_with_tagged_union_json_serialization(
10001037 )
10011038
10021039 assert s .to_json (input , warnings = 'error' ) == expected
1040+
1041+
1042+ def test_discriminated_union_ser_with_typed_dict () -> None :
1043+ v = SchemaSerializer (
1044+ core_schema .tagged_union_schema (
1045+ {
1046+ 'a' : core_schema .typed_dict_schema (
1047+ {
1048+ 'type' : core_schema .typed_dict_field (core_schema .literal_schema (['a' ])),
1049+ 'a' : core_schema .typed_dict_field (core_schema .int_schema ()),
1050+ }
1051+ ),
1052+ 'b' : core_schema .typed_dict_schema (
1053+ {
1054+ 'type' : core_schema .typed_dict_field (core_schema .literal_schema (['b' ])),
1055+ 'b' : core_schema .typed_dict_field (core_schema .str_schema ()),
1056+ }
1057+ ),
1058+ },
1059+ discriminator = 'type' ,
1060+ )
1061+ )
1062+
1063+ assert v .to_python ({'type' : 'a' , 'a' : 1 }, warnings = 'error' ) == {'type' : 'a' , 'a' : 1 }
1064+ assert v .to_python ({'type' : 'b' , 'b' : 'foo' }, warnings = 'error' ) == {'type' : 'b' , 'b' : 'foo' }
0 commit comments