55
66from __future__ import annotations
77
8+ from contextlib import suppress
89from typing import TYPE_CHECKING , ClassVar
910
1011from graphrag .config .enums import StorageType
1415from graphrag .storage .memory_pipeline_storage import MemoryPipelineStorage
1516
1617if TYPE_CHECKING :
18+ from collections .abc import Callable
19+
1720 from graphrag .storage .pipeline_storage import PipelineStorage
1821
1922
@@ -26,29 +29,73 @@ class StorageFactory:
2629 for individual enforcement of required/optional arguments.
2730 """
2831
29- storage_types : ClassVar [dict [str , type ]] = {}
32+ _storage_registry : ClassVar [dict [str , Callable [..., PipelineStorage ]]] = {}
33+ storage_types : ClassVar [dict [str , type ]] = {} # For backward compatibility
3034
3135 @classmethod
32- def register (cls , storage_type : str , storage : type ):
33- """Register a custom storage implementation."""
34- cls .storage_types [storage_type ] = storage
36+ def register (
37+ cls , storage_type : str , creator : Callable [..., PipelineStorage ]
38+ ) -> None :
39+ """Register a custom storage implementation.
40+
41+ Args:
42+ storage_type: The type identifier for the storage.
43+ creator: A callable that creates an instance of the storage.
44+ """
45+ cls ._storage_registry [storage_type ] = creator
46+
47+ # For backward compatibility with code that may access storage_types directly
48+ if (
49+ callable (creator )
50+ and hasattr (creator , "__annotations__" )
51+ and "return" in creator .__annotations__
52+ ):
53+ with suppress (TypeError , KeyError ):
54+ cls .storage_types [storage_type ] = creator .__annotations__ ["return" ]
3555
3656 @classmethod
3757 def create_storage (
3858 cls , storage_type : StorageType | str , kwargs : dict
3959 ) -> PipelineStorage :
40- """Create or get a storage object from the provided type."""
41- match storage_type :
42- case StorageType .blob :
43- return create_blob_storage (** kwargs )
44- case StorageType .cosmosdb :
45- return create_cosmosdb_storage (** kwargs )
46- case StorageType .file :
47- return create_file_storage (** kwargs )
48- case StorageType .memory :
49- return MemoryPipelineStorage ()
50- case _:
51- if storage_type in cls .storage_types :
52- return cls .storage_types [storage_type ](** kwargs )
53- msg = f"Unknown storage type: { storage_type } "
54- raise ValueError (msg )
60+ """Create a storage object from the provided type.
61+
62+ Args:
63+ storage_type: The type of storage to create.
64+ kwargs: Additional keyword arguments for the storage constructor.
65+
66+ Returns
67+ -------
68+ A PipelineStorage instance.
69+
70+ Raises
71+ ------
72+ ValueError: If the storage type is not registered.
73+ """
74+ storage_type_str = (
75+ storage_type .value
76+ if isinstance (storage_type , StorageType )
77+ else storage_type
78+ )
79+
80+ if storage_type_str not in cls ._storage_registry :
81+ msg = f"Unknown storage type: { storage_type } "
82+ raise ValueError (msg )
83+
84+ return cls ._storage_registry [storage_type_str ](** kwargs )
85+
86+ @classmethod
87+ def get_storage_types (cls ) -> list [str ]:
88+ """Get the registered storage implementations."""
89+ return list (cls ._storage_registry .keys ())
90+
91+ @classmethod
92+ def is_supported_storage_type (cls , storage_type : str ) -> bool :
93+ """Check if the given storage type is supported."""
94+ return storage_type in cls ._storage_registry
95+
96+
97+ # --- Register default implementations ---
98+ StorageFactory .register (StorageType .blob .value , create_blob_storage )
99+ StorageFactory .register (StorageType .cosmosdb .value , create_cosmosdb_storage )
100+ StorageFactory .register (StorageType .file .value , create_file_storage )
101+ StorageFactory .register (StorageType .memory .value , lambda ** _ : MemoryPipelineStorage ())
0 commit comments