|
4 | 4 | from django.db import router |
5 | 5 | from django.db.backends.base.schema import BaseDatabaseSchemaEditor |
6 | 6 | from django.db.models import Index, UniqueConstraint |
7 | | -from pymongo.encryption import ClientEncryption |
8 | 7 | from pymongo.operations import SearchIndexModel |
9 | 8 |
|
10 | 9 | from django_mongodb_backend.indexes import SearchIndex |
@@ -457,79 +456,147 @@ def wait_until_index_dropped(collection, index_name, timeout=60, interval=0.5): |
457 | 456 |
|
458 | 457 | def _create_collection(self, model): |
459 | 458 | """ |
460 | | - Create a collection for the model with the encrypted fields. If |
461 | | - provided, use the `_encrypted_fields_map` in the client's |
462 | | - `auto_encryption_opts`. Otherwise, create the encrypted fields map |
463 | | - with `_get_encrypted_fields`. |
| 459 | + Create a collection for the model. |
| 460 | + If the model has encrypted fields, build (or retrieve) the encrypted_fields schema. |
464 | 461 | """ |
465 | 462 | db = self.get_database() |
466 | 463 | db_table = model._meta.db_table |
| 464 | + |
467 | 465 | if model_has_encrypted_fields(model): |
| 466 | + # Encrypted path |
468 | 467 | client = self.connection.connection |
469 | 468 | auto_encryption_opts = getattr(client._options, "auto_encryption_opts", None) |
470 | 469 | if not auto_encryption_opts: |
471 | 470 | raise ImproperlyConfigured( |
472 | 471 | f"Encrypted fields found but DATABASES['{self.connection.alias}']['OPTIONS'] " |
473 | 472 | "is missing auto_encryption_opts." |
474 | 473 | ) |
| 474 | + |
475 | 475 | encrypted_fields_map = getattr(auto_encryption_opts, "_encrypted_fields_map", None) |
| 476 | + |
476 | 477 | if not encrypted_fields_map: |
477 | | - encrypted_fields = self._get_encrypted_fields(model, client, create_data_keys=True) |
| 478 | + encrypted_fields = self._get_encrypted_fields(model, create_data_keys=True) |
478 | 479 | else: |
479 | | - # If the encrypted fields map is provided, get the encrypted fields for the |
480 | | - # specific collection. |
481 | 480 | encrypted_fields = encrypted_fields_map.get(db_table) |
482 | | - db.create_collection(db_table, encryptedFields=encrypted_fields) |
| 481 | + |
| 482 | + if encrypted_fields and encrypted_fields.get("fields"): |
| 483 | + db.create_collection(db_table, encryptedFields=encrypted_fields) |
| 484 | + else: |
| 485 | + db.create_collection(db_table) |
| 486 | + |
483 | 487 | else: |
| 488 | + # Unencrypted path |
484 | 489 | db.create_collection(db_table) |
485 | 490 |
|
486 | | - def _get_encrypted_fields(self, model, client, create_data_keys=False): |
| 491 | + def _get_encrypted_fields(self, model, create_data_keys=False, key_alt_name=None): |
| 492 | + """ |
| 493 | + Recursively collect encryption schema data for only encrypted fields in a model. |
| 494 | + Returns None if no encrypted fields are found anywhere in the model hierarchy. |
| 495 | +
|
| 496 | + key_alt_name is the base path for this level, typically model._meta.db_table. |
| 497 | + """ |
487 | 498 | connection = self.connection |
| 499 | + client = connection.connection |
488 | 500 | fields = model._meta.fields |
| 501 | + key_alt_name = key_alt_name or model._meta.db_table |
| 502 | + |
489 | 503 | options = client._options |
490 | | - auto_encryption_opts = options.auto_encryption_opts |
| 504 | + auto_encryption_opts = getattr(options, "auto_encryption_opts", None) |
| 505 | + |
| 506 | + key_vault_collection = None |
| 507 | + if auto_encryption_opts: |
| 508 | + key_vault_db, key_vault_coll = auto_encryption_opts._key_vault_namespace.split(".", 1) |
| 509 | + key_vault_collection = client[key_vault_db][key_vault_coll] |
| 510 | + |
491 | 511 | kms_provider = router.kms_provider(model) |
492 | | - master_key = self.connection.settings_dict.get("KMS_CREDENTIALS", {}).get(kms_provider) |
493 | | - client_encryption = ClientEncryption( |
494 | | - auto_encryption_opts._kms_providers, |
495 | | - auto_encryption_opts._key_vault_namespace, |
496 | | - client, |
497 | | - client.codec_options, |
498 | | - ) |
499 | | - key_vault_db, key_vault_coll = auto_encryption_opts._key_vault_namespace.split(".", 1) |
500 | | - key_vault_collection = client[key_vault_db][key_vault_coll] |
501 | | - db_table = model._meta.db_table |
| 512 | + master_key = connection.settings_dict.get("KMS_CREDENTIALS", {}).get(kms_provider) |
| 513 | + client_encryption = getattr(self.connection, "client_encryption", None) |
| 514 | + |
502 | 515 | field_list = [] |
| 516 | + |
503 | 517 | for field in fields: |
| 518 | + new_path = f"{key_alt_name}.{field.column}" |
| 519 | + |
| 520 | + # --- EmbeddedModelField --- |
504 | 521 | if isinstance(field, EmbeddedModelField): |
505 | | - # Recursively get encrypted fields for the embedded model. |
506 | | - self._get_encrypted_fields(field.embedded_model, client, create_data_keys) |
| 522 | + if getattr(field, "encrypted", False): |
| 523 | + # Entire sub-object is encrypted |
| 524 | + if create_data_keys: |
| 525 | + if not client_encryption: |
| 526 | + raise ImproperlyConfigured("client_encryption is not configured.") |
| 527 | + data_key = client_encryption.create_data_key( |
| 528 | + kms_provider=kms_provider, |
| 529 | + master_key=master_key, |
| 530 | + key_alt_names=[new_path], |
| 531 | + ) |
| 532 | + else: |
| 533 | + if key_vault_collection is None: |
| 534 | + raise ImproperlyConfigured( |
| 535 | + f"Encrypted field {new_path} detected but no key vault configured" |
| 536 | + ) |
| 537 | + key_doc = key_vault_collection.find_one({"keyAltNames": new_path}) |
| 538 | + if not key_doc: |
| 539 | + raise ValueError( |
| 540 | + f"No key found in keyvault for keyAltName={new_path}. " |
| 541 | + "Run with '--create-data-keys' to create missing keys." |
| 542 | + ) |
| 543 | + data_key = key_doc["_id"] |
| 544 | + |
| 545 | + field_dict = { |
| 546 | + "bsonType": "object", |
| 547 | + "path": field.column, |
| 548 | + "keyId": data_key, |
| 549 | + } |
| 550 | + if getattr(field, "queries", False): |
| 551 | + field_dict["queries"] = field.queries |
| 552 | + |
| 553 | + field_list.append(field_dict) |
| 554 | + else: |
| 555 | + # Not encrypting whole object — recurse first then |
| 556 | + # conditionally extend field list |
| 557 | + embedded_result = self._get_encrypted_fields( |
| 558 | + field.embedded_model, |
| 559 | + create_data_keys=create_data_keys, |
| 560 | + key_alt_name=new_path, |
| 561 | + ) |
| 562 | + if embedded_result and embedded_result.get("fields"): |
| 563 | + field_list.extend(embedded_result["fields"]) |
| 564 | + continue |
| 565 | + |
| 566 | + # --- Leaf encrypted field --- |
507 | 567 | if getattr(field, "encrypted", False): |
508 | | - key_alt_name = f"{db_table}.{field.column}" |
509 | 568 | if create_data_keys: |
| 569 | + if not client_encryption: |
| 570 | + raise ImproperlyConfigured("client_encryption is not configured.") |
510 | 571 | data_key = client_encryption.create_data_key( |
511 | 572 | kms_provider=kms_provider, |
512 | 573 | master_key=master_key, |
513 | | - key_alt_names=[key_alt_name], |
| 574 | + key_alt_names=[new_path], |
514 | 575 | ) |
515 | 576 | else: |
516 | | - key_doc = key_vault_collection.find_one({"keyAltNames": key_alt_name}) |
| 577 | + if key_vault_collection is None: |
| 578 | + raise ImproperlyConfigured( |
| 579 | + f"Encrypted field {new_path} detected but no key vault configured" |
| 580 | + ) |
| 581 | + key_doc = key_vault_collection.find_one({"keyAltNames": new_path}) |
517 | 582 | if not key_doc: |
518 | 583 | raise ValueError( |
519 | | - f"No key found in keyvault for keyAltName={key_alt_name}. " |
520 | | - "You may need to run the management command with " |
521 | | - "'--create-data-keys' to create missing keys." |
| 584 | + f"No key found in keyvault for keyAltName={new_path}. " |
| 585 | + "Run with '--create-data-keys' to create missing keys." |
522 | 586 | ) |
523 | 587 | data_key = key_doc["_id"] |
| 588 | + |
524 | 589 | field_dict = { |
525 | 590 | "bsonType": field.db_type(connection), |
526 | 591 | "path": field.column, |
527 | 592 | "keyId": data_key, |
528 | 593 | } |
529 | 594 | if getattr(field, "queries", False): |
530 | 595 | field_dict["queries"] = field.queries |
| 596 | + |
531 | 597 | field_list.append(field_dict) |
532 | | - return {"fields": field_list} |
| 598 | + |
| 599 | + return {"fields": field_list} if field_list else None |
533 | 600 |
|
534 | 601 |
|
535 | 602 | # GISSchemaEditor extends some SchemaEditor methods. |
|
0 commit comments