|
4 | 4 | from django.db import router |
5 | 5 | from django.db.backends.base.schema import BaseDatabaseSchemaEditor |
6 | 6 | from django.db.models import Index, UniqueConstraint |
7 | | -from pymongo.encryption import ClientEncryption |
8 | 7 | from pymongo.operations import SearchIndexModel |
9 | 8 |
|
10 | 9 | from django_mongodb_backend.indexes import SearchIndex |
@@ -457,79 +456,154 @@ def wait_until_index_dropped(collection, index_name, timeout=60, interval=0.5): |
457 | 456 |
|
458 | 457 | def _create_collection(self, model): |
459 | 458 | """ |
460 | | - Create a collection for the model with the encrypted fields. If |
461 | | - provided, use the `_encrypted_fields_map` in the client's |
462 | | - `auto_encryption_opts`. Otherwise, create the encrypted fields map |
463 | | - with `_get_encrypted_fields`. |
| 459 | + Create a collection for the model. |
| 460 | + If the model has encrypted fields, build (or retrieve) the encrypted_fields schema. |
464 | 461 | """ |
465 | 462 | db = self.get_database() |
466 | 463 | db_table = model._meta.db_table |
| 464 | + |
467 | 465 | if model_has_encrypted_fields(model): |
| 466 | + # Encrypted path |
468 | 467 | client = self.connection.connection |
469 | 468 | auto_encryption_opts = getattr(client._options, "auto_encryption_opts", None) |
470 | 469 | if not auto_encryption_opts: |
471 | 470 | raise ImproperlyConfigured( |
472 | 471 | f"Encrypted fields found but DATABASES['{self.connection.alias}']['OPTIONS'] " |
473 | 472 | "is missing auto_encryption_opts." |
474 | 473 | ) |
| 474 | + |
475 | 475 | encrypted_fields_map = getattr(auto_encryption_opts, "_encrypted_fields_map", None) |
| 476 | + |
476 | 477 | if not encrypted_fields_map: |
477 | | - encrypted_fields = self._get_encrypted_fields(model, client, create_data_keys=True) |
| 478 | + encrypted_fields = self._get_encrypted_fields(model, create_data_keys=True) |
478 | 479 | else: |
479 | | - # If the encrypted fields map is provided, get the encrypted fields for the |
480 | | - # specific collection. |
481 | 480 | encrypted_fields = encrypted_fields_map.get(db_table) |
482 | | - db.create_collection(db_table, encryptedFields=encrypted_fields) |
| 481 | + |
| 482 | + if encrypted_fields and encrypted_fields.get("fields"): |
| 483 | + db.create_collection(db_table, encryptedFields=encrypted_fields) |
| 484 | + else: |
| 485 | + db.create_collection(db_table) |
| 486 | + |
483 | 487 | else: |
| 488 | + # Unencrypted path |
484 | 489 | db.create_collection(db_table) |
485 | 490 |
|
486 | | - def _get_encrypted_fields(self, model, client, create_data_keys=False): |
| 491 | + def _get_encrypted_fields( |
| 492 | + self, model, create_data_keys=False, key_alt_name=None, parent_model=None |
| 493 | + ): |
| 494 | + """ |
| 495 | + Recursively collect encryption schema data for only encrypted fields in a model. |
| 496 | + Returns None if no encrypted fields are found anywhere in the model hierarchy. |
| 497 | +
|
| 498 | + key_alt_name is the base path used for keyAltNames. |
| 499 | + parent_model is the dot-notated path inside the document for schema mapping. |
| 500 | + """ |
487 | 501 | connection = self.connection |
| 502 | + client = connection.connection |
488 | 503 | fields = model._meta.fields |
| 504 | + key_alt_name = key_alt_name or model._meta.db_table |
| 505 | + parent_model = parent_model or "" |
| 506 | + |
489 | 507 | options = client._options |
490 | | - auto_encryption_opts = options.auto_encryption_opts |
| 508 | + auto_encryption_opts = getattr(options, "auto_encryption_opts", None) |
| 509 | + |
| 510 | + key_vault_collection = None |
| 511 | + if auto_encryption_opts: |
| 512 | + key_vault_db, key_vault_coll = auto_encryption_opts._key_vault_namespace.split(".", 1) |
| 513 | + key_vault_collection = client[key_vault_db][key_vault_coll] |
| 514 | + |
491 | 515 | kms_provider = router.kms_provider(model) |
492 | | - master_key = self.connection.settings_dict.get("KMS_CREDENTIALS", {}).get(kms_provider) |
493 | | - client_encryption = ClientEncryption( |
494 | | - auto_encryption_opts._kms_providers, |
495 | | - auto_encryption_opts._key_vault_namespace, |
496 | | - client, |
497 | | - client.codec_options, |
498 | | - ) |
499 | | - key_vault_db, key_vault_coll = auto_encryption_opts._key_vault_namespace.split(".", 1) |
500 | | - key_vault_collection = client[key_vault_db][key_vault_coll] |
501 | | - db_table = model._meta.db_table |
| 516 | + master_key = connection.settings_dict.get("KMS_CREDENTIALS", {}).get(kms_provider) |
| 517 | + client_encryption = getattr(self.connection, "client_encryption", None) |
| 518 | + |
502 | 519 | field_list = [] |
| 520 | + |
503 | 521 | for field in fields: |
| 522 | + new_key_alt_name = f"{key_alt_name}.{field.column}" |
| 523 | + new_parent_model = f"{parent_model}.{field.column}" if parent_model else field.column |
| 524 | + |
| 525 | + # --- EmbeddedModelField --- |
504 | 526 | if isinstance(field, EmbeddedModelField): |
505 | | - # Recursively get encrypted fields for the embedded model. |
506 | | - self._get_encrypted_fields(field.embedded_model, client, create_data_keys) |
| 527 | + if getattr(field, "encrypted", False): |
| 528 | + # Entire sub-object encrypted |
| 529 | + if create_data_keys: |
| 530 | + if not client_encryption: |
| 531 | + raise ImproperlyConfigured("client_encryption is not configured.") |
| 532 | + data_key = client_encryption.create_data_key( |
| 533 | + kms_provider=kms_provider, |
| 534 | + master_key=master_key, |
| 535 | + key_alt_names=[new_key_alt_name], |
| 536 | + ) |
| 537 | + else: |
| 538 | + if key_vault_collection is None: |
| 539 | + raise ImproperlyConfigured( |
| 540 | + f"Encrypted field {new_key_alt_name} detected " |
| 541 | + "but no key vault configured" |
| 542 | + ) |
| 543 | + key_doc = key_vault_collection.find_one({"keyAltNames": new_key_alt_name}) |
| 544 | + if not key_doc: |
| 545 | + raise ValueError( |
| 546 | + f"No key found in keyvault for keyAltName={new_key_alt_name}. " |
| 547 | + "Run with '--create-data-keys' to create missing keys." |
| 548 | + ) |
| 549 | + data_key = key_doc["_id"] |
| 550 | + |
| 551 | + field_dict = { |
| 552 | + "bsonType": "object", |
| 553 | + "path": new_parent_model, |
| 554 | + "keyId": data_key, |
| 555 | + } |
| 556 | + if getattr(field, "queries", False): |
| 557 | + field_dict["queries"] = field.queries |
| 558 | + |
| 559 | + field_list.append(field_dict) |
| 560 | + else: |
| 561 | + # Recurse into embedded model |
| 562 | + embedded_result = self._get_encrypted_fields( |
| 563 | + field.embedded_model, |
| 564 | + create_data_keys=create_data_keys, |
| 565 | + key_alt_name=new_key_alt_name, |
| 566 | + parent_model=new_parent_model, |
| 567 | + ) |
| 568 | + if embedded_result and embedded_result.get("fields"): |
| 569 | + field_list.extend(embedded_result["fields"]) |
| 570 | + continue |
| 571 | + |
| 572 | + # --- Leaf encrypted field --- |
507 | 573 | if getattr(field, "encrypted", False): |
508 | | - key_alt_name = f"{db_table}.{field.column}" |
509 | 574 | if create_data_keys: |
| 575 | + if not client_encryption: |
| 576 | + raise ImproperlyConfigured("client_encryption is not configured.") |
510 | 577 | data_key = client_encryption.create_data_key( |
511 | 578 | kms_provider=kms_provider, |
512 | 579 | master_key=master_key, |
513 | | - key_alt_names=[key_alt_name], |
| 580 | + key_alt_names=[new_key_alt_name], |
514 | 581 | ) |
515 | 582 | else: |
516 | | - key_doc = key_vault_collection.find_one({"keyAltNames": key_alt_name}) |
| 583 | + if key_vault_collection is None: |
| 584 | + raise ImproperlyConfigured( |
| 585 | + f"Encrypted field {new_key_alt_name} detected " |
| 586 | + "but no key vault configured" |
| 587 | + ) |
| 588 | + key_doc = key_vault_collection.find_one({"keyAltNames": new_key_alt_name}) |
517 | 589 | if not key_doc: |
518 | 590 | raise ValueError( |
519 | | - f"No key found in keyvault for keyAltName={key_alt_name}. " |
520 | | - "You may need to run the management command with " |
521 | | - "'--create-data-keys' to create missing keys." |
| 591 | + f"No key found in keyvault for keyAltName={new_key_alt_name}. " |
| 592 | + "Run with '--create-data-keys' to create missing keys." |
522 | 593 | ) |
523 | 594 | data_key = key_doc["_id"] |
| 595 | + |
524 | 596 | field_dict = { |
525 | 597 | "bsonType": field.db_type(connection), |
526 | | - "path": field.column, |
| 598 | + "path": new_parent_model, |
527 | 599 | "keyId": data_key, |
528 | 600 | } |
529 | 601 | if getattr(field, "queries", False): |
530 | 602 | field_dict["queries"] = field.queries |
| 603 | + |
531 | 604 | field_list.append(field_dict) |
532 | | - return {"fields": field_list} |
| 605 | + |
| 606 | + return {"fields": field_list} if field_list else None |
533 | 607 |
|
534 | 608 |
|
535 | 609 | # GISSchemaEditor extends some SchemaEditor methods. |
|
0 commit comments