@@ -38,12 +38,6 @@ struct snp_guest_dev {
3838 struct miscdevice misc ;
3939
4040 struct snp_msg_desc * msg_desc ;
41-
42- union {
43- struct snp_report_req report ;
44- struct snp_derived_key_req derived_key ;
45- struct snp_ext_report_req ext_report ;
46- } req ;
4741};
4842
4943/*
@@ -71,7 +65,7 @@ struct snp_req_resp {
7165
7266static int get_report (struct snp_guest_dev * snp_dev , struct snp_guest_request_ioctl * arg )
7367{
74- struct snp_report_req * report_req = & snp_dev -> req . report ;
68+ struct snp_report_req * report_req __free ( kfree ) = NULL ;
7569 struct snp_msg_desc * mdesc = snp_dev -> msg_desc ;
7670 struct snp_report_resp * report_resp ;
7771 struct snp_guest_req req = {};
@@ -80,6 +74,10 @@ static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_io
8074 if (!arg -> req_data || !arg -> resp_data )
8175 return - EINVAL ;
8276
77+ report_req = kzalloc (sizeof (* report_req ), GFP_KERNEL_ACCOUNT );
78+ if (!report_req )
79+ return - ENOMEM ;
80+
8381 if (copy_from_user (report_req , (void __user * )arg -> req_data , sizeof (* report_req )))
8482 return - EFAULT ;
8583
@@ -116,7 +114,7 @@ static int get_report(struct snp_guest_dev *snp_dev, struct snp_guest_request_io
116114
117115static int get_derived_key (struct snp_guest_dev * snp_dev , struct snp_guest_request_ioctl * arg )
118116{
119- struct snp_derived_key_req * derived_key_req = & snp_dev -> req . derived_key ;
117+ struct snp_derived_key_req * derived_key_req __free ( kfree ) = NULL ;
120118 struct snp_derived_key_resp derived_key_resp = {0 };
121119 struct snp_msg_desc * mdesc = snp_dev -> msg_desc ;
122120 struct snp_guest_req req = {};
@@ -136,6 +134,10 @@ static int get_derived_key(struct snp_guest_dev *snp_dev, struct snp_guest_reque
136134 if (sizeof (buf ) < resp_len )
137135 return - ENOMEM ;
138136
137+ derived_key_req = kzalloc (sizeof (* derived_key_req ), GFP_KERNEL_ACCOUNT );
138+ if (!derived_key_req )
139+ return - ENOMEM ;
140+
139141 if (copy_from_user (derived_key_req , (void __user * )arg -> req_data ,
140142 sizeof (* derived_key_req )))
141143 return - EFAULT ;
@@ -168,16 +170,21 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
168170 struct snp_req_resp * io )
169171
170172{
171- struct snp_ext_report_req * report_req = & snp_dev -> req . ext_report ;
173+ struct snp_ext_report_req * report_req __free ( kfree ) = NULL ;
172174 struct snp_msg_desc * mdesc = snp_dev -> msg_desc ;
173175 struct snp_report_resp * report_resp ;
174176 struct snp_guest_req req = {};
175177 int ret , npages = 0 , resp_len ;
176178 sockptr_t certs_address ;
179+ struct page * page ;
177180
178181 if (sockptr_is_null (io -> req_data ) || sockptr_is_null (io -> resp_data ))
179182 return - EINVAL ;
180183
184+ report_req = kzalloc (sizeof (* report_req ), GFP_KERNEL_ACCOUNT );
185+ if (!report_req )
186+ return - ENOMEM ;
187+
181188 if (copy_from_sockptr (report_req , io -> req_data , sizeof (* report_req )))
182189 return - EFAULT ;
183190
@@ -203,8 +210,20 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
203210 * the host. If host does not supply any certs in it, then copy
204211 * zeros to indicate that certificate data was not provided.
205212 */
206- memset (mdesc -> certs_data , 0 , report_req -> certs_len );
207213 npages = report_req -> certs_len >> PAGE_SHIFT ;
214+ page = alloc_pages (GFP_KERNEL_ACCOUNT | __GFP_ZERO ,
215+ get_order (report_req -> certs_len ));
216+ if (!page )
217+ return - ENOMEM ;
218+
219+ req .certs_data = page_address (page );
220+ ret = set_memory_decrypted ((unsigned long )req .certs_data , npages );
221+ if (ret ) {
222+ pr_err ("failed to mark page shared, ret=%d\n" , ret );
223+ __free_pages (page , get_order (report_req -> certs_len ));
224+ return - EFAULT ;
225+ }
226+
208227cmd :
209228 /*
210229 * The intermediate response buffer is used while decrypting the
@@ -213,10 +232,12 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
213232 */
214233 resp_len = sizeof (report_resp -> data ) + mdesc -> ctx -> authsize ;
215234 report_resp = kzalloc (resp_len , GFP_KERNEL_ACCOUNT );
216- if (!report_resp )
217- return - ENOMEM ;
235+ if (!report_resp ) {
236+ ret = - ENOMEM ;
237+ goto e_free_data ;
238+ }
218239
219- mdesc -> input .data_npages = npages ;
240+ req . input .data_npages = npages ;
220241
221242 req .msg_version = arg -> msg_version ;
222243 req .msg_type = SNP_MSG_REPORT_REQ ;
@@ -231,7 +252,7 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
231252
232253 /* If certs length is invalid then copy the returned length */
233254 if (arg -> vmm_error == SNP_GUEST_VMM_ERR_INVALID_LEN ) {
234- report_req -> certs_len = mdesc -> input .data_npages << PAGE_SHIFT ;
255+ report_req -> certs_len = req . input .data_npages << PAGE_SHIFT ;
235256
236257 if (copy_to_sockptr (io -> req_data , report_req , sizeof (* report_req )))
237258 ret = - EFAULT ;
@@ -240,7 +261,7 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
240261 if (ret )
241262 goto e_free ;
242263
243- if (npages && copy_to_sockptr (certs_address , mdesc -> certs_data , report_req -> certs_len )) {
264+ if (npages && copy_to_sockptr (certs_address , req . certs_data , report_req -> certs_len )) {
244265 ret = - EFAULT ;
245266 goto e_free ;
246267 }
@@ -250,6 +271,13 @@ static int get_ext_report(struct snp_guest_dev *snp_dev, struct snp_guest_reques
250271
251272e_free :
252273 kfree (report_resp );
274+ e_free_data :
275+ if (npages ) {
276+ if (set_memory_encrypted ((unsigned long )req .certs_data , npages ))
277+ WARN_ONCE (ret , "failed to restore encryption mask (leak it)\n" );
278+ else
279+ __free_pages (page , get_order (report_req -> certs_len ));
280+ }
253281 return ret ;
254282}
255283
0 commit comments