2424import java .util .stream .Collectors ;
2525
2626import org .bson .Document ;
27-
2827import org .springframework .data .mapping .PersistentProperty ;
2928import org .springframework .data .mapping .context .MappingContext ;
3029import org .springframework .data .mongodb .core .convert .MongoConverter ;
4544import org .springframework .util .Assert ;
4645import org .springframework .util .ClassUtils ;
4746import org .springframework .util .CollectionUtils ;
47+ import org .springframework .util .LinkedMultiValueMap ;
4848import org .springframework .util .ObjectUtils ;
4949import org .springframework .util .StringUtils ;
5050
@@ -62,6 +62,7 @@ class MappingMongoJsonSchemaCreator implements MongoJsonSchemaCreator {
6262 private final MongoConverter converter ;
6363 private final MappingContext <MongoPersistentEntity <?>, MongoPersistentProperty > mappingContext ;
6464 private final Predicate <JsonSchemaPropertyContext > filter ;
65+ private final LinkedMultiValueMap <String , Class <?>> mergeProperties ;
6566
6667 /**
6768 * Create a new instance of {@link MappingMongoJsonSchemaCreator}.
@@ -72,23 +73,51 @@ class MappingMongoJsonSchemaCreator implements MongoJsonSchemaCreator {
7273 MappingMongoJsonSchemaCreator (MongoConverter converter ) {
7374
7475 this (converter , (MappingContext <MongoPersistentEntity <?>, MongoPersistentProperty >) converter .getMappingContext (),
75- (property ) -> true );
76+ (property ) -> true , new LinkedMultiValueMap <>() );
7677 }
7778
7879 @ SuppressWarnings ("unchecked" )
7980 MappingMongoJsonSchemaCreator (MongoConverter converter ,
8081 MappingContext <MongoPersistentEntity <?>, MongoPersistentProperty > mappingContext ,
81- Predicate <JsonSchemaPropertyContext > filter ) {
82+ Predicate <JsonSchemaPropertyContext > filter , LinkedMultiValueMap < String , Class <?>> mergeProperties ) {
8283
8384 Assert .notNull (converter , "Converter must not be null!" );
8485 this .converter = converter ;
8586 this .mappingContext = mappingContext ;
8687 this .filter = filter ;
88+ this .mergeProperties = mergeProperties ;
8789 }
8890
8991 @ Override
9092 public MongoJsonSchemaCreator filter (Predicate <JsonSchemaPropertyContext > filter ) {
91- return new MappingMongoJsonSchemaCreator (converter , mappingContext , filter );
93+ return new MappingMongoJsonSchemaCreator (converter , mappingContext , filter , mergeProperties );
94+ }
95+
96+ @ Override
97+ public PropertySpecifier specify (String path ) {
98+ return new PropertySpecifier () {
99+ @ Override
100+ public MongoJsonSchemaCreator types (Class <?>... types ) {
101+ return specifyTypesFor (path , types );
102+ }
103+ };
104+ }
105+
106+ /**
107+ * Specify additional types to be considered wehen rendering the schema for the given path.
108+ *
109+ * @param path path the path using {@literal dot '.'} notation.
110+ * @param types must not be {@literal null}.
111+ * @return new instance of {@link MongoJsonSchemaCreator}.
112+ * @since 3.4
113+ */
114+ public MongoJsonSchemaCreator specifyTypesFor (String path , Class <?>... types ) {
115+
116+ LinkedMultiValueMap <String , Class <?>> clone = mergeProperties .clone ();
117+ for (Class <?> type : types ) {
118+ clone .add (path , type );
119+ }
120+ return new MappingMongoJsonSchemaCreator (converter , mappingContext , filter , clone );
92121 }
93122
94123 /*
@@ -135,9 +164,12 @@ private List<JsonSchemaProperty> computePropertiesForEntity(List<MongoPersistent
135164
136165 List <MongoPersistentProperty > currentPath = new ArrayList <>(path );
137166
138- if (!filter .test (new PropertyContext (
139- currentPath .stream ().map (PersistentProperty ::getName ).collect (Collectors .joining ("." )), nested ))) {
140- continue ;
167+ String stringPath = currentPath .stream ().map (PersistentProperty ::getName ).collect (Collectors .joining ("." ));
168+ stringPath = StringUtils .hasText (stringPath ) ? (stringPath + "." + nested .getName ()) : nested .getName ();
169+ if (!filter .test (new PropertyContext (stringPath , nested ))) {
170+ if (!mergeProperties .containsKey (stringPath )) {
171+ continue ;
172+ }
141173 }
142174
143175 if (path .contains (nested )) { // cycle guard
@@ -155,14 +187,34 @@ private List<JsonSchemaProperty> computePropertiesForEntity(List<MongoPersistent
155187
156188 private JsonSchemaProperty computeSchemaForProperty (List <MongoPersistentProperty > path ) {
157189
190+ String stringPath = path .stream ().map (MongoPersistentProperty ::getName ).collect (Collectors .joining ("." ));
158191 MongoPersistentProperty property = CollectionUtils .lastElement (path );
159192
160193 boolean required = isRequiredProperty (property );
161194 Class <?> rawTargetType = computeTargetType (property ); // target type before conversion
162195 Class <?> targetType = converter .getTypeMapper ().getWriteTargetTypeFor (rawTargetType ); // conversion target type
163196
164- if (!isCollection (property ) && property .isEntity () && ObjectUtils .nullSafeEquals (rawTargetType , targetType )) {
165- return createObjectSchemaPropertyForEntity (path , property , required );
197+ if (!isCollection (property ) && ObjectUtils .nullSafeEquals (rawTargetType , targetType )) {
198+ if (property .isEntity () || mergeProperties .containsKey (stringPath )) {
199+ List <JsonSchemaProperty > targetProperties = new ArrayList <>();
200+
201+ if (property .isEntity ()) {
202+ targetProperties .add (createObjectSchemaPropertyForEntity (path , property , required ));
203+ }
204+ if (mergeProperties .containsKey (stringPath )) {
205+ for (Class <?> theType : mergeProperties .get (stringPath )) {
206+
207+ ObjectJsonSchemaProperty target = JsonSchemaProperty .object (property .getName ());
208+ List <JsonSchemaProperty > nestedProperties = computePropertiesForEntity (path ,
209+ mappingContext .getRequiredPersistentEntity (theType ));
210+
211+ targetProperties .add (createPotentiallyRequiredSchemaProperty (
212+ target .properties (nestedProperties .toArray (new JsonSchemaProperty [0 ])), required ));
213+ }
214+ }
215+ return targetProperties .size () == 1 ? targetProperties .iterator ().next ()
216+ : JsonSchemaProperty .combined (targetProperties );
217+ }
166218 }
167219
168220 String fieldName = computePropertyFieldName (property );
0 commit comments