diff --git a/socketsync/core.py b/socketsync/core.py index 8365bdb..b1871a3 100644 --- a/socketsync/core.py +++ b/socketsync/core.py @@ -120,8 +120,10 @@ def set_org_vars() -> None: :return: """ log.debug("Getting Organization Configuration") - global org_id, org_slug, full_scan_path, repository_path, security_policy - org_id, org_slug = Core.get_org_id_slug() + global org_id, org_slug, org_plan, full_scan_path, repository_path, security_policy + org_id, org_slug, org_plan = Core.get_org_id_slug_plan() + # Ensure the organization is on the Enterprise plan + Core.assert_enterprise_plan() base_path = f"orgs/{org_slug}" full_scan_path = f"{base_path}/full-scans" repository_path = f"{base_path}/repos" @@ -148,20 +150,34 @@ def set_timeout(request_timeout: int): socketdev.set_timeout(timeout) @staticmethod - def get_org_id_slug() -> (str, str): + def get_org_id_slug_plan() -> (str, str, str): """ - Gets the Org ID and Org Slug for the API Token + Gets the Org ID, Org Slug, and Org Plan for the API Token :return: """ organizations = socket.org.get() orgs = organizations.get("organizations") new_org_id = None new_org_slug = None + new_org_plan = None if orgs is not None and len(orgs) == 1: for key in orgs: new_org_id = key new_org_slug = orgs[key].get("slug") - return new_org_id, new_org_slug + new_org_plan = orgs[key].get("plan") + return new_org_id, new_org_slug, new_org_plan + + @staticmethod + def assert_enterprise_plan() -> None: + """ + Validate that the current organization is on the Enterprise plan. If not, fail fast. + :param organization_id: The organization id to check + :return: None + :raises Exception: if the organization is not on the Enterprise plan + """ + is_enterprise = "enterprise" in org_plan.lower() + if not is_enterprise: + raise Exception("This script requires an Enterprise plan organization") @staticmethod def get_security_policy() -> dict: