2424import java .util .stream .Collectors ;
2525
2626import org .bson .Document ;
27-
2827import org .springframework .data .mapping .PersistentProperty ;
2928import org .springframework .data .mapping .context .MappingContext ;
3029import org .springframework .data .mongodb .core .convert .MongoConverter ;
4544import org .springframework .util .Assert ;
4645import org .springframework .util .ClassUtils ;
4746import org .springframework .util .CollectionUtils ;
47+ import org .springframework .util .LinkedMultiValueMap ;
4848import org .springframework .util .ObjectUtils ;
4949import org .springframework .util .StringUtils ;
5050
@@ -62,6 +62,7 @@ class MappingMongoJsonSchemaCreator implements MongoJsonSchemaCreator {
6262 private final MongoConverter converter ;
6363 private final MappingContext <MongoPersistentEntity <?>, MongoPersistentProperty > mappingContext ;
6464 private final Predicate <JsonSchemaPropertyContext > filter ;
65+ private final LinkedMultiValueMap <String , Class <?>> mergeProperties ;
6566
6667 /**
6768 * Create a new instance of {@link MappingMongoJsonSchemaCreator}.
@@ -72,23 +73,51 @@ class MappingMongoJsonSchemaCreator implements MongoJsonSchemaCreator {
7273 MappingMongoJsonSchemaCreator (MongoConverter converter ) {
7374
7475 this (converter , (MappingContext <MongoPersistentEntity <?>, MongoPersistentProperty >) converter .getMappingContext (),
75- (property ) -> true );
76+ (property ) -> true , new LinkedMultiValueMap <>() );
7677 }
7778
7879 @ SuppressWarnings ("unchecked" )
7980 MappingMongoJsonSchemaCreator (MongoConverter converter ,
8081 MappingContext <MongoPersistentEntity <?>, MongoPersistentProperty > mappingContext ,
81- Predicate <JsonSchemaPropertyContext > filter ) {
82+ Predicate <JsonSchemaPropertyContext > filter , LinkedMultiValueMap < String , Class <?>> mergeProperties ) {
8283
8384 Assert .notNull (converter , "Converter must not be null!" );
8485 this .converter = converter ;
8586 this .mappingContext = mappingContext ;
8687 this .filter = filter ;
88+ this .mergeProperties = mergeProperties ;
8789 }
8890
8991 @ Override
9092 public MongoJsonSchemaCreator filter (Predicate <JsonSchemaPropertyContext > filter ) {
91- return new MappingMongoJsonSchemaCreator (converter , mappingContext , filter );
93+ return new MappingMongoJsonSchemaCreator (converter , mappingContext , filter , mergeProperties );
94+ }
95+
96+ @ Override
97+ public PropertySpecifier specify (String path ) {
98+ return new PropertySpecifier () {
99+ @ Override
100+ public MongoJsonSchemaCreator types (Class <?>... types ) {
101+ return specifyTypesFor (path , types );
102+ }
103+ };
104+ }
105+
106+ /**
107+ * Specify additional types to be considered wehen rendering the schema for the given path.
108+ *
109+ * @param path path the path using {@literal dot '.'} notation.
110+ * @param types must not be {@literal null}.
111+ * @return new instance of {@link MongoJsonSchemaCreator}.
112+ * @since 3.4
113+ */
114+ public MongoJsonSchemaCreator specifyTypesFor (String path , Class <?>... types ) {
115+
116+ LinkedMultiValueMap <String , Class <?>> clone = mergeProperties .clone ();
117+ for (Class <?> type : types ) {
118+ clone .add (path , type );
119+ }
120+ return new MappingMongoJsonSchemaCreator (converter , mappingContext , filter , clone );
92121 }
93122
94123 @ Override
@@ -131,9 +160,12 @@ private List<JsonSchemaProperty> computePropertiesForEntity(List<MongoPersistent
131160
132161 List <MongoPersistentProperty > currentPath = new ArrayList <>(path );
133162
134- if (!filter .test (new PropertyContext (
135- currentPath .stream ().map (PersistentProperty ::getName ).collect (Collectors .joining ("." )), nested ))) {
136- continue ;
163+ String stringPath = currentPath .stream ().map (PersistentProperty ::getName ).collect (Collectors .joining ("." ));
164+ stringPath = StringUtils .hasText (stringPath ) ? (stringPath + "." + nested .getName ()) : nested .getName ();
165+ if (!filter .test (new PropertyContext (stringPath , nested ))) {
166+ if (!mergeProperties .containsKey (stringPath )) {
167+ continue ;
168+ }
137169 }
138170
139171 if (path .contains (nested )) { // cycle guard
@@ -151,14 +183,34 @@ private List<JsonSchemaProperty> computePropertiesForEntity(List<MongoPersistent
151183
152184 private JsonSchemaProperty computeSchemaForProperty (List <MongoPersistentProperty > path ) {
153185
186+ String stringPath = path .stream ().map (MongoPersistentProperty ::getName ).collect (Collectors .joining ("." ));
154187 MongoPersistentProperty property = CollectionUtils .lastElement (path );
155188
156189 boolean required = isRequiredProperty (property );
157190 Class <?> rawTargetType = computeTargetType (property ); // target type before conversion
158191 Class <?> targetType = converter .getTypeMapper ().getWriteTargetTypeFor (rawTargetType ); // conversion target type
159192
160- if (!isCollection (property ) && property .isEntity () && ObjectUtils .nullSafeEquals (rawTargetType , targetType )) {
161- return createObjectSchemaPropertyForEntity (path , property , required );
193+ if (!isCollection (property ) && ObjectUtils .nullSafeEquals (rawTargetType , targetType )) {
194+ if (property .isEntity () || mergeProperties .containsKey (stringPath )) {
195+ List <JsonSchemaProperty > targetProperties = new ArrayList <>();
196+
197+ if (property .isEntity ()) {
198+ targetProperties .add (createObjectSchemaPropertyForEntity (path , property , required ));
199+ }
200+ if (mergeProperties .containsKey (stringPath )) {
201+ for (Class <?> theType : mergeProperties .get (stringPath )) {
202+
203+ ObjectJsonSchemaProperty target = JsonSchemaProperty .object (property .getName ());
204+ List <JsonSchemaProperty > nestedProperties = computePropertiesForEntity (path ,
205+ mappingContext .getRequiredPersistentEntity (theType ));
206+
207+ targetProperties .add (createPotentiallyRequiredSchemaProperty (
208+ target .properties (nestedProperties .toArray (new JsonSchemaProperty [0 ])), required ));
209+ }
210+ }
211+ return targetProperties .size () == 1 ? targetProperties .iterator ().next ()
212+ : JsonSchemaProperty .combined (targetProperties );
213+ }
162214 }
163215
164216 String fieldName = computePropertyFieldName (property );
0 commit comments