diff --git a/src/oci_openai/oci_openai.py b/src/oci_openai/oci_openai.py index 650dd4c..385c0c0 100644 --- a/src/oci_openai/oci_openai.py +++ b/src/oci_openai/oci_openai.py @@ -182,12 +182,63 @@ class OciSessionAuth(HttpxOciAuth): """ def __init__( - self, config_file: str = DEFAULT_LOCATION, profile_name: str = DEFAULT_PROFILE + self, + config_file: str = DEFAULT_LOCATION, + profile_name: str = DEFAULT_PROFILE, + **kwargs: Mapping[str, Any], ): + """ + Initialize a Security Token-based OCI signer. + + Parameters + ---------- + config_file : str, optional + Path to the OCI configuration file. Defaults to `~/.oci/config`. + profile_name : str, optional + Profile name inside the OCI configuration file to use. Defaults to "DEFAULT". + **kwargs : Mapping[str, Any] + Optional keyword arguments: + - `generic_headers`: Optional[Dict[str, str]] + Headers to be used for generic requests. + Default: `["date", "(request-target)", "host"]` + - `body_headers`: Optional[Dict[str, str]] + Headers to be used for signed request bodies. + Default: `["content-length", "content-type", "x-content-sha256"]` + + Raises + ------ + oci.exceptions.ConfigFileNotFound + If the configuration file cannot be found. + KeyError + If a required key such as `"key_file"` is missing in the config. + Exception + For any other initialization errors. + """ + # Load OCI configuration and token config = oci.config.from_file(config_file, profile_name) token = self._load_token(config) - private_key = oci.signer.load_private_key_from_file(config["key_file"]) # type: ignore - self.signer = oci.auth.signers.SecurityTokenSigner(token, private_key) + + # Load the private key from config + key_path = config.get("key_file") + if not key_path: + raise KeyError( + f"Missing 'key_file' entry in OCI config profile '{profile_name}'." + ) + private_key = oci.signer.load_private_key_from_file(key_path) + + # Optional signer header customization + generic_headers = kwargs.pop("generic_headers", None) + body_headers = kwargs.pop("body_headers", None) + + additional_kwargs = {} + if generic_headers: + additional_kwargs["generic_headers"] = generic_headers + if body_headers: + additional_kwargs["body_headers"] = body_headers + + self.signer = oci.auth.signers.SecurityTokenSigner( + token, private_key, **additional_kwargs + ) def _load_token(self, config: Mapping[str, Any]) -> str: token_file = config["security_token_file"]