diff --git a/docker/zenml-dev.Dockerfile b/docker/zenml-dev.Dockerfile index 2f817eb4658..e9724ca9ac3 100644 --- a/docker/zenml-dev.Dockerfile +++ b/docker/zenml-dev.Dockerfile @@ -43,10 +43,11 @@ ENV PATH="$VIRTUAL_ENV/bin:$PATH" COPY README.md pyproject.toml ./ -# We first copy the __init__.py file to allow pip install-ing the Python +# We first copy the __init__.py files to allow pip install-ing the Python # dependencies as a separate cache layer. This way, we can avoid re-installing # the dependencies when the source code changes but the dependencies don't. COPY src/zenml/__init__.py ./src/zenml/ +COPY src/zenml_cli/__init__.py ./src/zenml_cli/ # Run pip install before copying the source files to install dependencies in # the virtual environment. Also create a requirements.txt file to keep track of diff --git a/docker/zenml-server-dev.Dockerfile b/docker/zenml-server-dev.Dockerfile index d5a042989a6..c62e5e7cc81 100644 --- a/docker/zenml-server-dev.Dockerfile +++ b/docker/zenml-server-dev.Dockerfile @@ -76,10 +76,11 @@ ENV PATH="$VIRTUAL_ENV/bin:$PATH" COPY --chown=$USERNAME:$USER_GID README.md pyproject.toml ./ -# We first copy the __init__.py file to allow pip install-ing the Python +# We first copy the __init__.py files to allow pip install-ing the Python # dependencies as a separate cache layer. This way, we can avoid re-installing # the dependencies when the source code changes but the dependencies don't. COPY --chown=$USERNAME:$USER_GID src/zenml/__init__.py ./src/zenml/ +COPY --chown=$USERNAME:$USER_GID src/zenml_cli/__init__.py ./src/zenml_cli/ # Run pip install before copying the source files to install dependencies in # the virtual environment. Also create a requirements.txt file to keep track of diff --git a/docs/book/reference/environment-variables.md b/docs/book/reference/environment-variables.md index 87625658cd2..534f43b66e9 100644 --- a/docs/book/reference/environment-variables.md +++ b/docs/book/reference/environment-variables.md @@ -129,6 +129,28 @@ To set the path to the global config file, used by ZenML to manage and store the export ZENML_CONFIG_PATH=/path/to/somewhere ``` +## CLI output formatting + +### Default output format + +Set the default output format for all CLI list commands: + +```bash +export ZENML_DEFAULT_OUTPUT=json +``` + +Choose from `table` (default), `json`, `yaml`, `csv`, or `tsv`. This applies to commands like `zenml stack list`, `zenml pipeline list`, etc. + +### Terminal width override + +Override the automatic terminal width detection for table rendering: + +```bash +export ZENML_CLI_COLUMN_WIDTH=120 +``` + +This is useful when running ZenML in CI/CD environments or when you want to control table formatting regardless of your terminal size. + ## Server configuration For more information on server configuration, see the [ZenML Server documentation](../getting-started/deploying-zenml/deploy-with-docker.md#zenml-server-configuration-options) for more, especially the section entitled "ZenML server configuration options". diff --git a/docs/book/reference/global-settings.md b/docs/book/reference/global-settings.md index 4ee0d99d027..d96c280b02d 100644 --- a/docs/book/reference/global-settings.md +++ b/docs/book/reference/global-settings.md @@ -47,6 +47,10 @@ Using the default local database. ┗━━━━━━━━┷━━━━━━━━━━━━┷━━━━━━━━┷━━━━━━━━━┷━━━━━━━━━━━━━━━━┷━━━━━━━━━━━━━━┛ ``` +{% hint style="info" %} +The output can be customized with an `--output` (json, yaml, csv, tsv, table) option and a `--columns` selection. See [environment variables](environment-variables.md#cli-output-formatting) for more details. +{% endhint %} + The following is an example of the layout of the global config directory immediately after initialization: ``` diff --git a/docs/book/user-guide/best-practices/quick-wins.md b/docs/book/user-guide/best-practices/quick-wins.md index 411cd027457..dbca492159f 100644 --- a/docs/book/user-guide/best-practices/quick-wins.md +++ b/docs/book/user-guide/best-practices/quick-wins.md @@ -25,6 +25,7 @@ micro-setup (under 5 minutes) and any tips or gotchas to anticipate. | [Model Control Plane](#id-12-register-models-in-the-model-control-plane) | Track models and their lifecycle | Central hub for model lineage and governance | | [Parent Docker images](#id-13-create-a-parent-docker-image-for-faster-builds) | Pre-configure your dependencies in a base image | Faster builds and consistent environments | | [ZenML docs via MCP](#id-14-enable-ide-ai-zenml-docs-via-mcp-server) | Connect your IDE assistant to live ZenML docs | Faster, grounded answers and doc lookups while coding | +| [Export CLI data](#id-15-export-cli-data-in-multiple-formats) | Get machine-readable output from list commands | Perfect for scripting, automation, and data analysis | ## 1 Log rich metadata on every run @@ -854,3 +855,61 @@ Using the zenmldocs MCP server, show me how to register an MLflow experiment tra - For bleeding-edge features on develop, consult the repo or develop docs directly Learn more: [Access ZenML documentation via llms.txt and MCP](https://docs.zenml.io/reference/llms-txt) + +## 15 Export CLI data in multiple formats + +All `zenml list` commands support multiple output formats for scripting, CI/CD integration, and data analysis. + +```bash +# Get stack data as JSON for processing with jq +zenml stack list --output=json | jq '.items[] | select(.name=="production")' + +# Export pipeline runs to CSV for analysis +zenml pipeline runs list --output=csv > pipeline_runs.csv + +# Get deployment info as YAML for configuration management +zenml deployment list --output=yaml + +# Filter columns to see only what you need +zenml stack list --columns=id,name,orchestrator + +# Combine filtering with custom output formats +zenml pipeline list --columns=id,name,num_runs --output=json +``` + +**Available formats** +- **json** - Structured data with pagination info, perfect for programmatic processing +- **yaml** - Human-readable structured format, great for configuration +- **csv** - Comma-separated values for spreadsheets and data analysis +- **tsv** - Tab-separated values for simpler parsing +- **table** (default) - Formatted tables with colors and alignment + +**Key features** +- **Column filtering** - Use `--columns` to show only the fields you need +- **Scriptable** - Combine with tools like `jq`, `grep`, `awk` for powerful automation +- **Environment control** - Set `ZENML_DEFAULT_OUTPUT` to change the default format +- **Width control** - Override terminal width with `ZENML_CLI_COLUMN_WIDTH` for consistent formatting + +**Best practices** +- Use JSON format for robust parsing in scripts (includes pagination metadata) +- Use CSV/TSV for importing into spreadsheet tools or databases +- Use `--columns` to reduce noise and focus on relevant data +- Set default formats via environment variables in CI/CD environments + +**Example automation script** +```bash +#!/bin/bash +# Export all production stacks to a report + +export ZENML_DEFAULT_OUTPUT=json + +# Get all stacks and filter for production +zenml stack list | jq '.items[] | select(.name | contains("prod"))' > prod_stacks.json + +# Generate a summary CSV +zenml stack list --output=csv --columns=name,orchestrator,artifact_store > stack_summary.csv + +echo "Reports generated: prod_stacks.json and stack_summary.csv" +``` + +Learn more: [Environment Variables](https://docs.zenml.io/reference/environment-variables#cli-output-formatting) diff --git a/docs/book/user-guide/production-guide/understand-stacks.md b/docs/book/user-guide/production-guide/understand-stacks.md index 230b78830de..caf766f14ad 100644 --- a/docs/book/user-guide/production-guide/understand-stacks.md +++ b/docs/book/user-guide/production-guide/understand-stacks.md @@ -49,6 +49,10 @@ Stack 'default' with id '...' is owned by user default and is 'private'. ... ``` +{% hint style="info" %} +You can customize the output using `--columns` to show specific fields or `--output` to change the format (json, yaml, csv, tsv). Learn more in the [Quick Wins guide](../best-practices/quick-wins.md#id-15-export-cli-data-in-multiple-formats). +{% endhint %} + {% hint style="info" %} As you can see a stack can be **active** on your **client**. This simply means that any pipeline you run will be using the **active stack** as its environment. {% endhint %} diff --git a/pyproject.toml b/pyproject.toml index 577398452f4..8a9a044336c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,6 +2,9 @@ requires = ["uv_build >= 0.8.17, <0.9.0"] build-backend = "uv_build" +[tool.uv.build-backend] +module-name = ["zenml", "zenml_cli"] + [project] name = "zenml" version = "0.91.2" @@ -51,7 +54,7 @@ Repository = "https://github.com/zenml-io/zenml" Issues = "https://github.com/zenml-io/zenml/issues" [project.scripts] -zenml = "zenml.cli.cli:cli" +zenml = "zenml_cli:cli" [project.optional-dependencies] local = [ diff --git a/src/zenml/cli/__init__.py b/src/zenml/cli/__init__.py index c6d2e438d20..d3296187ec5 100644 --- a/src/zenml/cli/__init__.py +++ b/src/zenml/cli/__init__.py @@ -225,6 +225,45 @@ This syntax can also be combined to create more complex filters using the `or` and `and` keywords. +Output formats +-------------- + +All ``list`` commands support multiple output formats for scripting, +CI/CD integration, and data analysis. + +Use the ``--output`` (or ``-o``) option to specify the format: + +```bash +# Get stack data as JSON for processing with jq +zenml stack list --output=json | jq '.items[] | select(.name=="production")' + +# Export pipeline runs to CSV for analysis +zenml pipeline runs list --output=csv > pipeline_runs.csv + +# Get deployment info as YAML for configuration management +zenml deployment list --output=yaml + +# Filter columns to see only what you need +zenml stack list --columns=id,name,orchestrator + +# Combine filtering with custom output formats +zenml pipeline list --columns=id,name --output=json +``` + +Available formats: + +- **json** - Structured data with pagination info, ideal for programmatic use +- **yaml** - Human-readable structured format, great for configuration +- **csv** - Comma-separated values for spreadsheets and data analysis +- **tsv** - Tab-separated values for simpler parsing +- **table** (default) - Formatted tables with colors and alignment + +You can also control the default output format and table width using +environment variables: + +- ``ZENML_DEFAULT_OUTPUT`` - Set the default output format (e.g., ``json``) +- ``ZENML_CLI_COLUMN_WIDTH`` - Override terminal width for consistent formatting + Artifact Stores --------------- diff --git a/src/zenml/cli/annotator.py b/src/zenml/cli/annotator.py index af5038218d9..494609e7852 100644 --- a/src/zenml/cli/annotator.py +++ b/src/zenml/cli/annotator.py @@ -73,9 +73,8 @@ def dataset_list(annotator: "BaseAnnotator") -> None: if not dataset_names: cli_utils.warning("No datasets found.") return - cli_utils.print_list_items( - list_items=dataset_names, - column_title="DATASETS", + cli_utils.print_table( + [{"DATASETS": name} for name in sorted(dataset_names)] ) @dataset.command("stats") diff --git a/src/zenml/cli/artifact.py b/src/zenml/cli/artifact.py index b94a1261623..8334bfd48c9 100644 --- a/src/zenml/cli/artifact.py +++ b/src/zenml/cli/artifact.py @@ -19,7 +19,9 @@ from zenml.cli import utils as cli_utils from zenml.cli.cli import TagGroup, cli +from zenml.cli.utils import OutputFormat, list_options from zenml.client import Client +from zenml.console import console from zenml.enums import CliCategories from zenml.logger import get_logger from zenml.models import ArtifactFilter, ArtifactVersionFilter @@ -35,25 +37,27 @@ def artifact() -> None: """Commands for interacting with artifacts.""" -@cli_utils.list_options(ArtifactFilter) @artifact.command("list", help="List all artifacts.") -def list_artifacts(**kwargs: Any) -> None: +@list_options( + ArtifactFilter, + default_columns=["id", "name", "latest_version_name", "tags"], +) +def list_artifacts( + columns: str, output_format: OutputFormat, **kwargs: Any +) -> None: """List all artifacts. Args: + columns: Columns to display in output. + output_format: Format for output (table/json/yaml/csv/tsv). **kwargs: Keyword arguments to filter artifacts by. """ - artifacts = Client().list_artifacts(**kwargs) - - if not artifacts: - cli_utils.declare("No artifacts found.") - return + with console.status("Listing artifacts...\n"): + artifacts = Client().list_artifacts(**kwargs) - to_print = [] - for artifact in artifacts: - to_print.append(_artifact_to_print(artifact)) - - cli_utils.print_table(to_print) + cli_utils.print_page( + artifacts, columns, output_format, empty_message="No artifacts found." + ) @artifact.command("update", help="Update an artifact.") @@ -115,25 +119,30 @@ def version() -> None: """Commands for interacting with artifact versions.""" -@cli_utils.list_options(ArtifactVersionFilter) @version.command("list", help="List all artifact versions.") -def list_artifact_versions(**kwargs: Any) -> None: +@list_options( + ArtifactVersionFilter, + default_columns=["id", "artifact", "version", "type"], +) +def list_artifact_versions( + columns: str, output_format: OutputFormat, **kwargs: Any +) -> None: """List all artifact versions. Args: + columns: Columns to display in output. + output_format: Format for output (table/json/yaml/csv/tsv). **kwargs: Keyword arguments to filter artifact versions by. """ - artifact_versions = Client().list_artifact_versions(**kwargs) - - if not artifact_versions: - cli_utils.declare("No artifact versions found.") - return - - to_print = [] - for artifact_version in artifact_versions: - to_print.append(_artifact_version_to_print(artifact_version)) - - cli_utils.print_table(to_print) + with console.status("Listing artifact versions...\n"): + artifact_versions = Client().list_artifact_versions(**kwargs) + + cli_utils.print_page( + artifact_versions, + columns, + output_format, + empty_message="No artifact versions found.", + ) @version.command("describe", help="Show details about an artifact version.") diff --git a/src/zenml/cli/authorized_device.py b/src/zenml/cli/authorized_device.py index 71ef4edc5d4..1609e527785 100644 --- a/src/zenml/cli/authorized_device.py +++ b/src/zenml/cli/authorized_device.py @@ -19,7 +19,7 @@ from zenml.cli import utils as cli_utils from zenml.cli.cli import TagGroup, cli -from zenml.cli.utils import list_options +from zenml.cli.utils import OutputFormat, list_options from zenml.client import Client from zenml.console import console from zenml.enums import CliCategories @@ -59,24 +59,35 @@ def describe_authorized_device(id_or_prefix: str) -> None: @authorized_device.command( "list", help="List all authorized devices for the current user." ) -@list_options(OAuthDeviceFilter) -def list_authorized_devices(**kwargs: Any) -> None: +@list_options( + OAuthDeviceFilter, + default_columns=[ + "id", + "status", + "expires", + "hostname", + "os", + ], +) +def list_authorized_devices( + columns: str, output_format: OutputFormat, **kwargs: Any +) -> None: """List all authorized devices. Args: + columns: Columns to display in output. + output_format: Format for output (table/json/yaml/csv/tsv). **kwargs: Keyword arguments to filter authorized devices. """ with console.status("Listing authorized devices...\n"): devices = Client().list_authorized_devices(**kwargs) - if not devices.items: - cli_utils.declare("No authorized devices found for this filter.") - return - - cli_utils.print_pydantic_models( - devices, - columns=["id", "status", "ip_address", "hostname", "os"], - ) + cli_utils.print_page( + devices, + columns, + output_format, + empty_message="No authorized devices found for this filter.", + ) @authorized_device.command("lock") diff --git a/src/zenml/cli/base.py b/src/zenml/cli/base.py index d80f428b92d..72d142a1999 100644 --- a/src/zenml/cli/base.py +++ b/src/zenml/cli/base.py @@ -659,7 +659,11 @@ def info( write_yaml(file, user_info) declare(f"Wrote user debug info to file at '{file_write_path}'.") else: - cli_utils.print_user_info(user_info) + for key, value in user_info.items(): + if key in ["packages", "query_packages"] and not bool(value): + continue + + declare(f"{key.upper()}: {value}") if stack: try: diff --git a/src/zenml/cli/cli.py b/src/zenml/cli/cli.py index baa29bafbe5..7b594cfeb64 100644 --- a/src/zenml/cli/cli.py +++ b/src/zenml/cli/cli.py @@ -43,7 +43,7 @@ def __init__( commands: Optional[ Union[Dict[str, click.Command], Sequence[click.Command]] ] = None, - **kwargs: Dict[str, Any], + **kwargs: Any, ) -> None: """Initialize the Tag group. @@ -53,7 +53,7 @@ def __init__( commands: The commands of the group. kwargs: Additional keyword arguments. """ - super(TagGroup, self).__init__(name, commands, **kwargs) + super(TagGroup, self).__init__(name=name, commands=commands, **kwargs) self.tag = tag or CliCategories.OTHER_COMMANDS diff --git a/src/zenml/cli/code_repository.py b/src/zenml/cli/code_repository.py index d83ab6e75cf..28efe46ef65 100644 --- a/src/zenml/cli/code_repository.py +++ b/src/zenml/cli/code_repository.py @@ -19,7 +19,7 @@ from zenml.cli import utils as cli_utils from zenml.cli.cli import TagGroup, cli -from zenml.cli.utils import list_options +from zenml.cli.utils import OutputFormat, list_options from zenml.client import Client from zenml.code_repositories import BaseCodeRepository from zenml.config.source import Source @@ -189,24 +189,29 @@ def describe_code_repository(name_id_or_prefix: str) -> None: @code_repository.command("list", help="List all connected code repositories.") -@list_options(CodeRepositoryFilter) -def list_code_repositories(**kwargs: Any) -> None: +@list_options( + CodeRepositoryFilter, + default_columns=["id", "name", "type", "source", "project", "description"], +) +def list_code_repositories( + columns: str, output_format: OutputFormat, **kwargs: Any +) -> None: """List all connected code repositories. Args: + columns: Columns to display in output. + output_format: Format for output (table/json/yaml/csv/tsv). **kwargs: Keyword arguments to filter code repositories. """ with console.status("Listing code repositories...\n"): repos = Client().list_code_repositories(**kwargs) - if not repos.items: - cli_utils.declare("No code repositories found for this filter.") - return - - cli_utils.print_pydantic_models( - repos, - exclude_columns=["created", "updated", "user", "project"], - ) + cli_utils.print_page( + repos, + columns, + output_format, + empty_message="No code repositories found for this filter.", + ) @code_repository.command( diff --git a/src/zenml/cli/deployment.py b/src/zenml/cli/deployment.py index 3be2b2bb86b..bf2030bb8cb 100644 --- a/src/zenml/cli/deployment.py +++ b/src/zenml/cli/deployment.py @@ -21,7 +21,7 @@ from zenml.cli import utils as cli_utils from zenml.cli.cli import TagGroup, cli -from zenml.cli.utils import fetch_snapshot, list_options +from zenml.cli.utils import OutputFormat, fetch_snapshot, list_options from zenml.client import Client from zenml.console import console from zenml.deployers.exceptions import DeploymentInvalidParametersError @@ -72,11 +72,25 @@ def deployment() -> None: @deployment.command("list", help="List all registered deployments.") -@list_options(DeploymentFilter) -def list_deployments(**kwargs: Any) -> None: +@list_options( + DeploymentFilter, + default_columns=[ + "id", + "name", + "status", + "url", + "pipeline", + "stack", + ], +) +def list_deployments( + columns: str, output_format: OutputFormat, **kwargs: Any +) -> None: """List all registered deployments for the filter. Args: + columns: Columns to display in output. + output_format: Format for output (table/json/yaml/csv/tsv). **kwargs: Keyword arguments to filter deployments. """ client = Client() @@ -85,13 +99,15 @@ def list_deployments(**kwargs: Any) -> None: deployments = client.list_deployments(**kwargs) except KeyError as err: cli_utils.exception(err) - else: - if not deployments.items: - cli_utils.declare("No deployments found for this filter.") - return - - cli_utils.print_deployment_table(deployments=deployments.items) - cli_utils.print_page_info(deployments) + return + + cli_utils.print_page( + deployments, + columns, + output_format, + cli_utils.generate_deployment_row, + empty_message="No deployments found for this filter.", + ) @deployment.command("describe", help="Describe a deployment.") diff --git a/src/zenml/cli/model.py b/src/zenml/cli/model.py index 82af9513fe1..66747a0f146 100644 --- a/src/zenml/cli/model.py +++ b/src/zenml/cli/model.py @@ -19,7 +19,9 @@ from zenml.cli import utils as cli_utils from zenml.cli.cli import TagGroup, cli +from zenml.cli.utils import OutputFormat, list_options from zenml.client import Client +from zenml.console import console from zenml.enums import CliCategories, ModelStages from zenml.exceptions import EntityExistsError from zenml.logger import get_logger @@ -77,23 +79,32 @@ def model() -> None: """Interact with models and model versions in the Model Control Plane.""" -@cli_utils.list_options(ModelFilter) @model.command("list", help="List models with filter.") -def list_models(**kwargs: Any) -> None: +@list_options( + ModelFilter, + default_columns=[ + "id", + "name", + "latest_version_name", + "tags", + ], +) +def list_models( + columns: str, output_format: OutputFormat, **kwargs: Any +) -> None: """List models with filter in the Model Control Plane. Args: + columns: Columns to display in output. + output_format: Format for output (table/json/yaml/csv/tsv). **kwargs: Keyword arguments to filter models. """ - models = Client().list_models(**kwargs) + with console.status("Listing models...\n"): + models = Client().list_models(**kwargs) - if not models: - cli_utils.declare("No models found.") - return - to_print = [] - for model in models: - to_print.append(_model_to_print(model)) - cli_utils.print_table(to_print) + cli_utils.print_page( + models, columns, output_format, empty_message="No models found." + ) @model.command("register", help="Register a new model.") @@ -390,25 +401,29 @@ def version() -> None: """Interact with model versions in the Model Control Plane.""" -@cli_utils.list_options(ModelVersionFilter) @version.command("list", help="List model versions with filter.") -def list_model_versions(**kwargs: Any) -> None: +@list_options( + ModelVersionFilter, default_columns=["id", "model", "number", "stage"] +) +def list_model_versions( + columns: str, output_format: OutputFormat, **kwargs: Any +) -> None: """List model versions with filter in the Model Control Plane. Args: + columns: Columns to display in output. + output_format: Format for output (table/json/yaml/csv/tsv). **kwargs: Keyword arguments to filter models. """ - model_versions = Client().list_model_versions(**kwargs) - - if not model_versions: - cli_utils.declare("No model versions found.") - return - - to_print = [] - for model_version in model_versions: - to_print.append(_model_version_to_print(model_version)) - - cli_utils.print_table(to_print) + with console.status("Listing model versions...\n"): + model_versions = Client().list_model_versions(**kwargs) + + cli_utils.print_page( + model_versions, + columns, + output_format, + empty_message="No model versions found.", + ) @version.command("update", help="Update an existing model version stage.") @@ -567,6 +582,8 @@ def delete_model_version( def _print_artifacts_links_generic( model_name_or_id: str, + columns: str, + output_format: OutputFormat, model_version_name_or_number_or_id: Optional[str] = None, only_data_artifacts: bool = False, only_deployment_artifacts: bool = False, @@ -577,6 +594,8 @@ def _print_artifacts_links_generic( Args: model_name_or_id: The ID or name of the model containing version. + columns: Columns to display in output. + output_format: Format for output (table/json/yaml/csv/tsv). model_version_name_or_number_or_id: The name, number or ID of the model version. only_data_artifacts: If set, only print data artifacts. only_deployment_artifacts: If set, only print deployment artifacts. @@ -595,25 +614,23 @@ def _print_artifacts_links_generic( else "model artifacts" ) - links = Client().list_model_version_artifact_links( - model_version_id=model_version.id, - only_data_artifacts=only_data_artifacts, - only_deployment_artifacts=only_deployment_artifacts, - only_model_artifacts=only_model_artifacts, - **kwargs, - ) + with console.status(f"Listing {type_}...\n"): + links = Client().list_model_version_artifact_links( + model_version_id=model_version.id, + only_data_artifacts=only_data_artifacts, + only_deployment_artifacts=only_deployment_artifacts, + only_model_artifacts=only_model_artifacts, + **kwargs, + ) - if not links: + if not links.items: cli_utils.declare(f"No {type_} linked to the model version found.") return cli_utils.title( f"{type_} linked to the model version `{model_version.name}[{model_version.number}]`:" ) - cli_utils.print_pydantic_models( - links, - columns=["artifact_version", "created"], - ) + cli_utils.print_page(links, columns, output_format) @model.command( @@ -622,9 +639,11 @@ def _print_artifacts_links_generic( ) @click.argument("model_name") @click.option("--model_version", "-v", default=None) -@cli_utils.list_options(ModelVersionArtifactFilter) +@list_options(ModelVersionArtifactFilter) def list_model_version_data_artifacts( model_name: str, + columns: str, + output_format: OutputFormat, model_version: Optional[str] = None, **kwargs: Any, ) -> None: @@ -632,12 +651,16 @@ def list_model_version_data_artifacts( Args: model_name: The ID or name of the model containing version. + columns: Columns to display in output. + output_format: Format for output (table/json/yaml/csv/tsv). model_version: The name, number or ID of the model version. If not provided, the latest version is used. **kwargs: Keyword arguments to filter models. """ _print_artifacts_links_generic( model_name_or_id=model_name, + columns=columns, + output_format=output_format, model_version_name_or_number_or_id=model_version, only_data_artifacts=True, **kwargs, @@ -650,9 +673,11 @@ def list_model_version_data_artifacts( ) @click.argument("model_name") @click.option("--model_version", "-v", default=None) -@cli_utils.list_options(ModelVersionArtifactFilter) +@list_options(ModelVersionArtifactFilter) def list_model_version_model_artifacts( model_name: str, + columns: str, + output_format: OutputFormat, model_version: Optional[str] = None, **kwargs: Any, ) -> None: @@ -660,12 +685,16 @@ def list_model_version_model_artifacts( Args: model_name: The ID or name of the model containing version. + columns: Columns to display in output. + output_format: Format for output (table/json/yaml/csv/tsv). model_version: The name, number or ID of the model version. If not provided, the latest version is used. **kwargs: Keyword arguments to filter models. """ _print_artifacts_links_generic( model_name_or_id=model_name, + columns=columns, + output_format=output_format, model_version_name_or_number_or_id=model_version, only_model_artifacts=True, **kwargs, @@ -678,9 +707,11 @@ def list_model_version_model_artifacts( ) @click.argument("model_name") @click.option("--model_version", "-v", default=None) -@cli_utils.list_options(ModelVersionArtifactFilter) +@list_options(ModelVersionArtifactFilter) def list_model_version_deployment_artifacts( model_name: str, + columns: str, + output_format: OutputFormat, model_version: Optional[str] = None, **kwargs: Any, ) -> None: @@ -688,12 +719,16 @@ def list_model_version_deployment_artifacts( Args: model_name: The ID or name of the model containing version. + columns: Columns to display in output. + output_format: Format for output (table/json/yaml/csv/tsv). model_version: The name, number or ID of the model version. If not provided, the latest version is used. **kwargs: Keyword arguments to filter models. """ _print_artifacts_links_generic( model_name_or_id=model_name, + columns=columns, + output_format=output_format, model_version_name_or_number_or_id=model_version, only_deployment_artifacts=True, **kwargs, @@ -706,9 +741,11 @@ def list_model_version_deployment_artifacts( ) @click.argument("model_name") @click.option("--model_version", "-v", default=None) -@cli_utils.list_options(ModelVersionPipelineRunFilter) +@list_options(ModelVersionPipelineRunFilter) def list_model_version_pipeline_runs( model_name: str, + columns: str, + output_format: OutputFormat, model_version: Optional[str] = None, **kwargs: Any, ) -> None: @@ -716,6 +753,8 @@ def list_model_version_pipeline_runs( Args: model_name: The ID or name of the model containing version. + columns: Columns to display in output. + output_format: Format for output (table/json/yaml/csv/tsv). model_version: The name, number or ID of the model version. If not provided, the latest version is used. **kwargs: Keyword arguments to filter runs. @@ -725,16 +764,17 @@ def list_model_version_pipeline_runs( model_version_name_or_number_or_id=model_version, ) - runs = Client().list_model_version_pipeline_run_links( - model_version_id=model_version_response_model.id, - **kwargs, - ) + with console.status("Listing pipeline runs...\n"): + runs = Client().list_model_version_pipeline_run_links( + model_version_id=model_version_response_model.id, + **kwargs, + ) - if not runs: + if not runs.items: cli_utils.declare("No pipeline runs attached to model version found.") return cli_utils.title( f"Pipeline runs linked to the model version `{model_version_response_model.name}[{model_version_response_model.number}]`:" ) - cli_utils.print_pydantic_models(runs) + cli_utils.print_page(runs, columns, output_format) diff --git a/src/zenml/cli/pipeline.py b/src/zenml/cli/pipeline.py index 1a652219501..4988def264a 100644 --- a/src/zenml/cli/pipeline.py +++ b/src/zenml/cli/pipeline.py @@ -21,7 +21,7 @@ from zenml.cli import utils as cli_utils from zenml.cli.cli import TagGroup, cli -from zenml.cli.utils import fetch_snapshot, list_options +from zenml.cli.utils import OutputFormat, fetch_snapshot, list_options from zenml.client import Client from zenml.console import console from zenml.deployers.base_deployer import BaseDeployer @@ -574,25 +574,30 @@ def create_run_template( @pipeline.command("list", help="List all registered pipelines.") -@list_options(PipelineFilter) -def list_pipelines(**kwargs: Any) -> None: +@list_options( + PipelineFilter, + default_columns=["id", "name", "user", "latest_run_status", "created"], +) +def list_pipelines( + columns: str, output_format: OutputFormat, **kwargs: Any +) -> None: """List all registered pipelines. Args: + columns: Columns to display in output. + output_format: Format for output (table/json/yaml/csv/tsv). **kwargs: Keyword arguments to filter pipelines. """ client = Client() with console.status("Listing pipelines...\n"): pipelines = client.list_pipelines(**kwargs) - if not pipelines.items: - cli_utils.declare("No pipelines found for this filter.") - return - - cli_utils.print_pydantic_models( - pipelines, - exclude_columns=["id", "created", "updated", "user", "project"], - ) + cli_utils.print_page( + pipelines, + columns, + output_format, + empty_message="No pipelines found for this filter.", + ) @pipeline.command("delete") @@ -639,24 +644,26 @@ def schedule() -> None: @schedule.command("list", help="List all pipeline schedules.") -@list_options(ScheduleFilter) -def list_schedules(**kwargs: Any) -> None: +@list_options( + ScheduleFilter, + default_columns=["id", "name", "pipeline", "cron_expression"], +) +def list_schedules( + columns: str, output_format: OutputFormat, **kwargs: Any +) -> None: """List all pipeline schedules. Args: + columns: Columns to display in output. + output_format: Format for output (table/json/yaml/csv/tsv). **kwargs: Keyword arguments to filter schedules. """ - client = Client() - - schedules = client.list_schedules(**kwargs) - - if not schedules: - cli_utils.declare("No schedules found for this filter.") - return - - cli_utils.print_pydantic_models( + schedules = Client().list_schedules(**kwargs) + cli_utils.print_page( schedules, - exclude_columns=["id", "created", "updated", "user", "project"], + columns, + output_format, + empty_message="No schedules found for this filter.", ) @@ -731,11 +738,18 @@ def runs() -> None: @runs.command("list", help="List all registered pipeline runs.") -@list_options(PipelineRunFilter) -def list_pipeline_runs(**kwargs: Any) -> None: +@list_options( + PipelineRunFilter, + default_columns=["id", "run_name", "pipeline", "status", "stack", "owner"], +) +def list_pipeline_runs( + columns: str, output_format: OutputFormat, **kwargs: Any +) -> None: """List all registered pipeline runs for the filter. Args: + columns: Columns to display in output. + output_format: Format for output (table/json/yaml/csv/tsv). **kwargs: Keyword arguments to filter pipeline runs. """ client = Client() @@ -744,13 +758,15 @@ def list_pipeline_runs(**kwargs: Any) -> None: pipeline_runs = client.list_pipeline_runs(**kwargs) except KeyError as err: cli_utils.exception(err) - else: - if not pipeline_runs.items: - cli_utils.declare("No pipeline runs found for this filter.") - return + return - cli_utils.print_pipeline_runs_table(pipeline_runs=pipeline_runs.items) - cli_utils.print_page_info(pipeline_runs) + cli_utils.print_page( + pipeline_runs, + columns, + output_format, + cli_utils.generate_pipeline_run_row, + empty_message="No pipeline runs found for this filter.", + ) @runs.command("stop") @@ -882,11 +898,26 @@ def builds() -> None: @builds.command("list", help="List all pipeline builds.") -@list_options(PipelineBuildFilter) -def list_pipeline_builds(**kwargs: Any) -> None: +@list_options( + PipelineBuildFilter, + default_columns=[ + "id", + "pipeline", + "stack", + "zenml_version", + "python_version", + "checksum", + "is_local", + ], +) +def list_pipeline_builds( + columns: str, output_format: OutputFormat, **kwargs: Any +) -> None: """List all pipeline builds for the filter. Args: + columns: Columns to display in output. + output_format: Format for output (table/json/yaml/csv/tsv). **kwargs: Keyword arguments to filter pipeline builds. """ client = Client() @@ -895,22 +926,14 @@ def list_pipeline_builds(**kwargs: Any) -> None: pipeline_builds = client.list_builds(hydrate=True, **kwargs) except KeyError as err: cli_utils.exception(err) - else: - if not pipeline_builds.items: - cli_utils.declare("No pipeline builds found for this filter.") - return + return - cli_utils.print_pydantic_models( - pipeline_builds, - exclude_columns=[ - "created", - "updated", - "user", - "project", - "images", - "stack_checksum", - ], - ) + cli_utils.print_page( + pipeline_builds, + columns, + output_format, + empty_message="No pipeline builds found for this filter.", + ) @builds.command("delete") @@ -1246,49 +1269,41 @@ def deploy_snapshot( @snapshot.command("list", help="List pipeline snapshots.") -@list_options(PipelineSnapshotFilter) -def list_pipeline_snapshots(**kwargs: Any) -> None: +@list_options( + PipelineSnapshotFilter, + default_columns=[ + "id", + "name", + "pipeline", + "is_dynamic", + "runnable", + "deployable", + ], +) +def list_pipeline_snapshots( + columns: str, output_format: OutputFormat, **kwargs: Any +) -> None: """List all pipeline snapshots for the filter. Args: + columns: Columns to display in output. + output_format: Format for output (table/json/yaml/csv/tsv). **kwargs: Keyword arguments to filter pipeline snapshots. """ client = Client() try: with console.status("Listing pipeline snapshots...\n"): - pipeline_snapshots = client.list_snapshots(hydrate=True, **kwargs) + pipeline_snapshots = client.list_snapshots(**kwargs) except KeyError as err: cli_utils.exception(err) - else: - if not pipeline_snapshots.items: - cli_utils.declare("No pipeline snapshots found for this filter.") - return + return - cli_utils.print_pydantic_models( - pipeline_snapshots, - exclude_columns=[ - "created", - "updated", - "user_id", - "project_id", - "pipeline_configuration", - "step_configurations", - "client_environment", - "client_version", - "server_version", - "run_name_template", - "pipeline_version_hash", - "pipeline_spec", - "build", - "schedule", - "code_reference", - "config_schema", - "config_template", - "source_snapshot_id", - "template_id", - "code_path", - ], - ) + cli_utils.print_page( + pipeline_snapshots, + columns, + output_format, + empty_message="No pipeline snapshots found for this filter.", + ) @snapshot.command("delete", help="Delete a pipeline snapshot.") diff --git a/src/zenml/cli/project.py b/src/zenml/cli/project.py index 2512fe58f71..6cca4647fda 100644 --- a/src/zenml/cli/project.py +++ b/src/zenml/cli/project.py @@ -20,6 +20,7 @@ from zenml.cli import utils as cli_utils from zenml.cli.cli import TagGroup, cli from zenml.cli.utils import ( + OutputFormat, check_zenml_pro_project_availability, is_sorted_or_filtered, list_options, @@ -36,32 +37,50 @@ def project() -> None: @project.command("list") -@list_options(ProjectFilter) +@list_options( + ProjectFilter, default_columns=["active", "id", "name", "description"] +) @click.pass_context -def list_projects(ctx: click.Context, /, **kwargs: Any) -> None: +def list_projects( + ctx: click.Context, + /, + columns: str, + output_format: OutputFormat, + **kwargs: Any, +) -> None: """List all projects. Args: ctx: The click context object + columns: Columns to display in output. + output_format: Format for output (table/json/yaml/csv/tsv). **kwargs: Keyword arguments to filter the list of projects. """ check_zenml_pro_project_availability() client = Client() with console.status("Listing projects...\n"): - projects = client.list_projects(**kwargs) - if projects: - try: - active_project = [client.active_project] - except Exception: - active_project = [] - cli_utils.print_pydantic_models( - projects, - exclude_columns=["id", "created", "updated"], - active_models=active_project, - show_active=not is_sorted_or_filtered(ctx), - ) - else: - cli_utils.declare("No projects found for the given filter.") + projects = client.list_projects(**kwargs, hydrate=True) + + show_active = not is_sorted_or_filtered(ctx) + if show_active and projects.items: + try: + active_project_id = client.active_project.id + if active_project_id not in {p.id for p in projects.items}: + projects.items.insert(0, client.active_project) + projects.items.sort(key=lambda p: p.id != active_project_id) + except Exception: + active_project_id = None + else: + active_project_id = None + + cli_utils.print_page( + projects, + columns, + output_format, + empty_message="No projects found for the given filter.", + row_generator=cli_utils.generate_project_row, + active_id=active_project_id, + ) @project.command("register") diff --git a/src/zenml/cli/secret.py b/src/zenml/cli/secret.py index 5c5a38b74e0..6e1f7512636 100644 --- a/src/zenml/cli/secret.py +++ b/src/zenml/cli/secret.py @@ -20,6 +20,7 @@ from zenml.cli.cli import TagGroup, cli from zenml.cli.utils import ( + OutputFormat, confirmation, convert_structured_str_to_dict, declare, @@ -28,8 +29,7 @@ list_options, parse_name_and_extra_arguments, pretty_print_secret, - print_page_info, - print_table, + print_page, validate_keys, warning, ) @@ -165,33 +165,29 @@ def create_secret( @secret.command( "list", help="List all registered secrets that match the filter criteria." ) -@list_options(SecretFilter) -def list_secrets(**kwargs: Any) -> None: +@list_options(SecretFilter, default_columns=["id", "name"]) +def list_secrets( + columns: str, output_format: OutputFormat, **kwargs: Any +) -> None: """List all secrets that fulfill the filter criteria. Args: + columns: Columns to display in output. + output_format: Format for output (table/json/yaml/csv/tsv). kwargs: Keyword arguments to filter the secrets. """ client = Client() - with console.status("Listing secrets..."): + with console.status("Listing secrets...\n"): try: secrets = client.list_secrets(**kwargs) except NotImplementedError as e: error(f"Centralized secrets management is disabled: {str(e)}") - if not secrets.items: - warning("No secrets found for the given filters.") - return - secret_rows = [ - dict( - name=secret.name, - id=str(secret.id), - private=secret.private, - ) - for secret in secrets.items - ] - print_table(secret_rows) - print_page_info(secrets) + if not secrets.items: + warning("No secrets found for the given filters.") + return + + print_page(secrets, columns, output_format) @secret.command("get", help="Get a secret with a given name, prefix or id.") diff --git a/src/zenml/cli/server.py b/src/zenml/cli/server.py index 2d96eaf84fc..93868168b8c 100644 --- a/src/zenml/cli/server.py +++ b/src/zenml/cli/server.py @@ -736,16 +736,13 @@ def server_list( cli_utils.print_pydantic_models( # type: ignore[type-var] all_servers, columns=columns, - rename_columns={ - "server_name_hyperlink": "name", - "server_id_hyperlink": "ID", - "organization_hyperlink": "organization", - "dashboard_url": "dashboard URL", - "api_hyperlink": "API URL", - "auth_status": "auth status", - }, active_models=current_server, show_active=True, + column_aliases={ + "server_id_hyperlink": "server_id", + "server_name_hyperlink": "server_name", + "organization_hyperlink": "organization", + }, ) diff --git a/src/zenml/cli/service_accounts.py b/src/zenml/cli/service_accounts.py index 9147b4f09c1..33d8c261e08 100644 --- a/src/zenml/cli/service_accounts.py +++ b/src/zenml/cli/service_accounts.py @@ -19,7 +19,7 @@ from zenml.cli import utils as cli_utils from zenml.cli.cli import TagGroup, cli -from zenml.cli.utils import list_options +from zenml.cli.utils import OutputFormat, list_options from zenml.client import Client from zenml.console import console from zenml.enums import CliCategories, StoreType @@ -185,31 +185,33 @@ def describe_service_account(service_account_name_or_id: str) -> None: @service_account.command("list") -@list_options(ServiceAccountFilter) +@list_options(ServiceAccountFilter, default_columns=["id", "name"]) @click.pass_context -def list_service_accounts(ctx: click.Context, /, **kwargs: Any) -> None: +def list_service_accounts( + ctx: click.Context, + /, + columns: str, + output_format: OutputFormat, + **kwargs: Any, +) -> None: """List all users. Args: ctx: The click context object + columns: Columns to display in output. + output_format: Format for output (table/json/yaml/csv/tsv). kwargs: Keyword arguments to filter the list of users. """ client = Client() with console.status("Listing service accounts...\n"): service_accounts = client.list_service_accounts(**kwargs) - if not service_accounts: - cli_utils.declare( - "No service accounts found for the given filters." - ) - return - cli_utils.print_pydantic_models( - service_accounts, - exclude_columns=[ - "created", - "updated", - ], - ) + cli_utils.print_page( + service_accounts, + columns, + output_format, + empty_message="No service accounts found for the given filters.", + ) @service_account.command( @@ -382,14 +384,22 @@ def describe_api_key(service_account_name_or_id: str, name_or_id: str) -> None: @api_key.command("list", help="List all API keys.") -@list_options(APIKeyFilter) +@list_options(APIKeyFilter, default_columns=["id", "name"]) @click.pass_obj -def list_api_keys(service_account_name_or_id: str, /, **kwargs: Any) -> None: +def list_api_keys( + service_account_name_or_id: str, + /, + columns: str, + output_format: OutputFormat, + **kwargs: Any, +) -> None: """List all API keys. Args: service_account_name_or_id: The name or ID of the service account for which to list the API keys. + columns: Columns to display in output. + output_format: Format for output (table/json/yaml/csv/tsv). **kwargs: Keyword arguments to filter API keys. """ with console.status("Listing API keys...\n"): @@ -400,20 +410,14 @@ def list_api_keys(service_account_name_or_id: str, /, **kwargs: Any) -> None: ) except KeyError as e: cli_utils.exception(e) - - if not api_keys.items: - cli_utils.declare("No API keys found for this filter.") return - cli_utils.print_pydantic_models( - api_keys, - exclude_columns=[ - "created", - "updated", - "key", - "retain_period_minutes", - ], - ) + cli_utils.print_page( + api_keys, + columns, + output_format, + empty_message="No API keys found for this filter.", + ) @api_key.command("update", help="Update an API key.") diff --git a/src/zenml/cli/service_connectors.py b/src/zenml/cli/service_connectors.py index f4819b397a2..12d24522a92 100644 --- a/src/zenml/cli/service_connectors.py +++ b/src/zenml/cli/service_connectors.py @@ -14,6 +14,7 @@ """Service connector CLI commands.""" from datetime import datetime +from functools import partial from typing import Any, Dict, List, Optional, Union, cast from uuid import UUID @@ -22,9 +23,9 @@ from zenml.cli import utils as cli_utils from zenml.cli.cli import TagGroup, cli from zenml.cli.utils import ( + OutputFormat, is_sorted_or_filtered, list_options, - print_page_info, ) from zenml.client import Client from zenml.console import console @@ -964,7 +965,18 @@ def register_service_connector( help="""List available service connectors. """, ) -@list_options(ServiceConnectorFilter) +@list_options( + ServiceConnectorFilter, + default_columns=[ + "active", + "id", + "name", + "type", + "resource_types", + "resource_name", + "owner", + ], +) @click.option( "--label", "-l", @@ -975,12 +987,19 @@ def register_service_connector( ) @click.pass_context def list_service_connectors( - ctx: click.Context, /, labels: Optional[List[str]] = None, **kwargs: Any + ctx: click.Context, + /, + columns: str, + output_format: OutputFormat, + labels: Optional[List[str]] = None, + **kwargs: Any, ) -> None: """List all service connectors. Args: ctx: The click context object + columns: Columns to display in output. + output_format: Format for output (table/json/yaml/csv/tsv). labels: Labels to filter by. kwargs: Keyword arguments to filter the components. """ @@ -992,16 +1011,44 @@ def list_service_connectors( ) connectors = client.list_service_connectors(**kwargs) - if not connectors: - cli_utils.declare("No service connectors found for the given filters.") - return - - cli_utils.print_service_connectors_table( - client=client, - connectors=connectors.items, - show_active=not is_sorted_or_filtered(ctx), + + show_active = not is_sorted_or_filtered(ctx) + if show_active and connectors.items: + active_connectors: List["ServiceConnectorResponse"] = [] + for components in client.active_stack_model.components.values(): + for component in components: + if component.connector: + connector = component.connector + if connector.id not in [c.id for c in active_connectors]: + if isinstance(connector.connector_type, str): + connector.set_connector_type( + client.get_service_connector_type( + connector.connector_type + ) + ) + active_connectors.append(connector) + + active_connector_ids = [c.id for c in active_connectors] + + for active_connector in active_connectors: + if active_connector.id not in {c.id for c in connectors.items}: + connectors.items.append(active_connector) + + connectors.items.sort(key=lambda c: c.id not in active_connector_ids) + else: + active_connector_ids = None + + row_formatter = partial( + cli_utils.generate_connector_row, + active_connector_ids=active_connector_ids, + ) + cli_utils.print_page( + connectors, + columns, + output_format, + row_formatter, + empty_message="No service connectors found for the given filters.", ) - print_page_info(connectors) @service_connector.command( diff --git a/src/zenml/cli/stack.py b/src/zenml/cli/stack.py index 6e82a0a9d32..fc4c740cd92 100644 --- a/src/zenml/cli/stack.py +++ b/src/zenml/cli/stack.py @@ -42,12 +42,11 @@ from zenml.cli.cli import TagGroup, cli from zenml.cli.text_utils import OldSchoolMarkdownHeading from zenml.cli.utils import ( + OutputFormat, _component_display_name, is_sorted_or_filtered, list_options, print_model_url, - print_page_info, - print_stacks_table, ) from zenml.client import Client from zenml.console import console @@ -1041,28 +1040,58 @@ def rename_stack( print_model_url(get_stack_url(stack_)) -@stack.command("list") -@list_options(StackFilter) +@stack.command( + "list", help="List all stacks that fulfill the filter requirements." +) +@list_options( + StackFilter, + default_columns=[ + "active", + "id", + "name", + "user", + "artifact_store", + "orchestrator", + "deployer", + ], +) @click.pass_context -def list_stacks(ctx: click.Context, /, **kwargs: Any) -> None: +def list_stacks( + ctx: click.Context, + /, + columns: str, + output_format: OutputFormat, + **kwargs: Any, +) -> None: """List all stacks that fulfill the filter requirements. Args: - ctx: the Click context - kwargs: Keyword arguments to filter the stacks. + ctx: The Click context. + columns: Columns to display in output. + output_format: Format for output (table/json/yaml/csv/tsv). + **kwargs: Keyword arguments to filter the stacks. """ client = Client() with console.status("Listing stacks...\n"): stacks = client.list_stacks(**kwargs) - if not stacks: - cli_utils.declare("No stacks found for the given filters.") - return - print_stacks_table( - client=client, - stacks=stacks.items, - show_active=not is_sorted_or_filtered(ctx), - ) - print_page_info(stacks) + + show_active = not is_sorted_or_filtered(ctx) + if show_active and stacks.items: + active_stack_id = client.active_stack_model.id + if active_stack_id not in {s.id for s in stacks.items}: + stacks.items.insert(0, client.active_stack_model) + stacks.items.sort(key=lambda s: s.id != active_stack_id) + else: + active_stack_id = None + + cli_utils.print_page( + stacks, + columns, + output_format, + empty_message="No stacks found for the given filters.", + row_generator=cli_utils.generate_stack_row, + active_id=active_stack_id, + ) @stack.command( diff --git a/src/zenml/cli/stack_components.py b/src/zenml/cli/stack_components.py index 87e435b6d1a..690794520fa 100644 --- a/src/zenml/cli/stack_components.py +++ b/src/zenml/cli/stack_components.py @@ -28,11 +28,11 @@ from zenml.cli.model_registry import register_model_registry_subcommands from zenml.cli.served_model import register_model_deployer_subcommands from zenml.cli.utils import ( + OutputFormat, _component_display_name, is_sorted_or_filtered, list_options, print_model_url, - print_page_info, ) from zenml.client import Client from zenml.console import console @@ -152,32 +152,63 @@ def generate_stack_component_list_command( A function that can be used as a `click` command. """ - @list_options(ComponentFilter) + @list_options( + ComponentFilter, + default_columns=["active", "id", "name", "flavor", "owner"], + ) @click.pass_context def list_stack_components_command( - ctx: click.Context, /, **kwargs: Any + ctx: click.Context, + /, + columns: str, + output_format: OutputFormat, + **kwargs: Any, ) -> None: """Prints a table of stack components. Args: ctx: The click context object + columns: Columns to display in output. + output_format: Format for output (table/json/yaml/csv/tsv). kwargs: Keyword arguments to filter the components. """ client = Client() with console.status(f"Listing {component_type.plural}..."): kwargs["type"] = component_type components = client.list_stack_components(**kwargs) - if not components: - cli_utils.declare("No components found for the given filters.") - return - cli_utils.print_components_table( - client=client, - component_type=component_type, - components=components.items, - show_active=not is_sorted_or_filtered(ctx), + show_active = not is_sorted_or_filtered(ctx) + if show_active and components.items: + active_stack = client.active_stack_model + active_component = None + if component_type in active_stack.components.keys(): + active_components = active_stack.components[component_type] + active_component = ( + active_components[0] if active_components else None + ) + + if active_component is not None: + active_component_id = active_component.id + if active_component_id not in { + c.id for c in components.items + }: + components.items.insert(0, active_component) + components.items.sort( + key=lambda c: c.id != active_component_id + ) + else: + active_component_id = None + else: + active_component_id = None + + cli_utils.print_page( + components, + columns, + output_format, + empty_message="No components found for the given filters.", + row_generator=cli_utils.generate_component_row, + active_id=active_component_id, ) - print_page_info(components) return list_stack_components_command diff --git a/src/zenml/cli/tag.py b/src/zenml/cli/tag.py index a45b17293be..c0c9cc796c3 100644 --- a/src/zenml/cli/tag.py +++ b/src/zenml/cli/tag.py @@ -37,23 +37,24 @@ def tag() -> None: """Interact with tags.""" -@cli_utils.list_options(TagFilter) @tag.command("list", help="List tags with filter.") -def list_tags(**kwargs: Any) -> None: +@cli_utils.list_options( + TagFilter, + default_columns=["id", "name", "color", "exclusive", "user"], +) +def list_tags( + columns: str, output_format: cli_utils.OutputFormat, **kwargs: Any +) -> None: """List tags with filter. Args: - **kwargs: Keyword arguments to filter models. + columns: Columns to display in output. + output_format: Format for output (table/json/yaml/csv/tsv). + **kwargs: Keyword arguments to filter tags. """ tags = Client().list_tags(**kwargs) - - if not tags: - cli_utils.declare("No tags found.") - return - - cli_utils.print_pydantic_models( - tags, - exclude_columns=["created"], + cli_utils.print_page( + tags, columns, output_format, empty_message="No tags found." ) diff --git a/src/zenml/cli/user_management.py b/src/zenml/cli/user_management.py index 71abbc371ca..e13dbafc778 100644 --- a/src/zenml/cli/user_management.py +++ b/src/zenml/cli/user_management.py @@ -19,7 +19,11 @@ from zenml.cli import utils as cli_utils from zenml.cli.cli import TagGroup, cli -from zenml.cli.utils import is_sorted_or_filtered, list_options +from zenml.cli.utils import ( + OutputFormat, + is_sorted_or_filtered, + list_options, +) from zenml.client import Client from zenml.config.global_config import GlobalConfiguration from zenml.console import console @@ -77,34 +81,47 @@ def describe_user(user_name_or_id: Optional[str] = None) -> None: @user.command("list") -@list_options(UserFilter) +@list_options( + UserFilter, default_columns=["active", "id", "name", "full_name"] +) @click.pass_context -def list_users(ctx: click.Context, /, **kwargs: Any) -> None: +def list_users( + ctx: click.Context, + /, + columns: str, + output_format: OutputFormat, + **kwargs: Any, +) -> None: """List all users. Args: ctx: The click context object + columns: Columns to display in output. + output_format: Format for output (table/json/yaml/csv/tsv). kwargs: Keyword arguments to filter the list of users. """ client = Client() - with console.status("Listing stacks...\n"): + with console.status("Listing users...\n"): users = client.list_users(**kwargs) - if not users: - cli_utils.declare("No users found for the given filters.") - return - cli_utils.print_pydantic_models( - users, - exclude_columns=[ - "created", - "updated", - "email", - "email_opted_in", - "activation_token", - ], - active_models=[Client().active_user], - show_active=not is_sorted_or_filtered(ctx), - ) + # Handle active user highlighting (only if not filtered/sorted) + show_active = not is_sorted_or_filtered(ctx) + if show_active and users.items: + active_user_id = client.active_user.id + if active_user_id not in {u.id for u in users.items}: + users.items.insert(0, client.active_user) + users.items.sort(key=lambda u: u.id != active_user_id) + else: + active_user_id = None + + cli_utils.print_page( + users, + columns, + output_format, + empty_message="No users found for the given filters.", + row_generator=cli_utils.generate_user_row, + active_id=active_user_id, + ) @user.command( diff --git a/src/zenml/cli/utils.py b/src/zenml/cli/utils.py index 962c470e0a2..b5bb20d906c 100644 --- a/src/zenml/cli/utils.py +++ b/src/zenml/cli/utils.py @@ -14,7 +14,8 @@ """Utility functions for the CLI.""" import contextlib -import functools +import csv +import io import json import os import platform @@ -22,6 +23,7 @@ import shutil import subprocess import sys +from functools import partial from typing import ( IO, TYPE_CHECKING, @@ -31,6 +33,7 @@ Dict, Iterator, List, + Literal, NoReturn, Optional, Sequence, @@ -40,7 +43,9 @@ TypeVar, Union, cast, + get_args, ) +from uuid import UUID import click import yaml @@ -49,16 +54,20 @@ from rich.console import Console from rich.emoji import Emoji, NoEmoji from rich.markdown import Markdown -from rich.markup import escape +from rich.padding import Padding from rich.prompt import Confirm, Prompt from rich.style import Style +from rich.syntax import Syntax from rich.table import Table from zenml.client import Client from zenml.console import console, zenml_style_defaults from zenml.constants import ( + ENV_ZENML_CLI_COLUMN_WIDTH, + ENV_ZENML_DEFAULT_OUTPUT, FILTERING_DATETIME_FORMAT, IS_DEBUG_ENV, + handle_int_env_var, ) from zenml.deployers.utils import ( get_deployment_input_schema, @@ -84,14 +93,15 @@ Page, ServiceConnectorRequirements, StrFilter, + UserResponse, + UserScopedResponse, UUIDFilter, ) from zenml.models.v2.base.filter import FilterGenerator from zenml.services import BaseService -from zenml.stack import StackComponent from zenml.stack.flavor import Flavor from zenml.stack.stack_component import StackComponentConfig -from zenml.utils import dict_utils, secret_utils +from zenml.utils import secret_utils from zenml.utils.package_utils import requirement_installed from zenml.utils.time_utils import expires_in from zenml.utils.typing_utils import get_origin, is_union @@ -112,17 +122,25 @@ FlavorResponse, PipelineRunResponse, PipelineSnapshotResponse, + ProjectResponse, ResourceTypeModel, ServiceConnectorRequest, ServiceConnectorResourcesModel, ServiceConnectorResponse, ServiceConnectorTypeModel, StackResponse, + UserResponse, ) from zenml.stack import Stack + logger = get_logger(__name__) +AnyResponse = TypeVar( + "AnyResponse", bound=BaseIdentifiedResponse[Any, Any, Any] +) +OutputFormat = Literal["table", "json", "yaml", "csv", "tsv"] + MAX_ARGUMENT_VALUE_SIZE = 10240 @@ -183,12 +201,10 @@ def error(text: str) -> NoReturn: error_prefix = click.style("Error: ", fg="red", bold=True) error_message = click.style(text, fg="red", bold=False) - # Create a custom ClickException that bypasses Click's default "Error: " prefix class StyledClickException(click.ClickException): def show(self, file: Optional[IO[Any]] = None) -> None: if file is None: file = click.get_text_stream("stderr") - # Print our custom styled message directly without Click's prefix click.echo(self.message, file=file) raise StyledClickException(message=error_prefix + error_message) @@ -257,87 +273,58 @@ def success( console.print(text, style=style, **kwargs) -def print_markdown(text: str) -> None: - """Prints a string as markdown. - - Args: - text: Markdown string to be printed. - """ - markdown_text = Markdown(text) - console.print(markdown_text) - - -def print_markdown_with_pager(text: str) -> None: - """Prints a string as markdown with a pager. - - Args: - text: Markdown string to be printed. - """ - markdown_text = Markdown(text) - with console.pager(): - console.print(markdown_text) - - def print_table( obj: List[Dict[str, Any]], title: Optional[str] = None, caption: Optional[str] = None, - **columns: table.Column, + columns: Optional[Dict[str, Union[table.Column, str]]] = None, + column_aliases: Optional[Dict[str, str]] = None, + **kwargs: table.Column, ) -> None: - """Prints the list of dicts in a table format. - - The input object should be a List of Dicts. Each item in that list represent - a line in the Table. Each dict should have the same keys. The keys of the - dict will be used as headers of the resulting table. + """Prints a list of dicts as a table. Args: obj: A List containing dictionaries. title: Title of the table. caption: Caption of the table. - columns: Optional column configurations to be used in the table. - """ - from rich.text import Text + columns: Optional mapping of data keys to column configurations. + Values can be either a rich Column object (uses .header for display) + or a string (used directly as the header). Keys not in this mapping + will use the uppercased key as the header. + column_aliases: Optional mapping of original column names to display + names. Use this to rename columns in the table output. + **kwargs: Deprecated. Use `columns` dict instead. Kept for backward + compatibility. + """ + all_columns: Dict[str, Union[table.Column, str]] = dict(kwargs) + if columns: + all_columns.update(columns) + + data = obj + if all_columns: + data = [] + for row in obj: + new_row = {} + for k, v in row.items(): + col_config = all_columns.get(k) + if col_config is None: + header = k.upper() + elif isinstance(col_config, table.Column): + header = str(col_config.header) + else: + header = str(col_config) + new_row[header] = v + data.append(new_row) - column_keys = {key: None for dict_ in obj for key in dict_} - column_names = [columns.get(key, key.upper()) for key in column_keys] - rich_table = table.Table( - box=box.ROUNDED, - show_lines=True, - title=title, - caption=caption, - border_style="dim", + if not data: + return + + print( + _render_table( + data, title=title, caption=caption, column_aliases=column_aliases + ), + end="", ) - for col_name in column_names: - if isinstance(col_name, str): - rich_table.add_column(str(col_name), overflow="fold") - else: - rich_table.add_column( - str(col_name.header).upper(), overflow="fold" - ) - for dict_ in obj: - values = [] - for key in column_keys: - if key is None: - values.append(None) - else: - v = dict_.get(key) or " " - if isinstance(v, str) and ( - v.startswith("http://") or v.startswith("https://") - ): - # Display the URL as a hyperlink in a way that doesn't break - # the URL when it needs to be wrapped over multiple lines - value: Union[str, Text] = Text(v, style=f"link {v}") - else: - value = str(v) - # Escape text when square brackets are used, but allow - # links to be decorated as rich style links - if "[" in value and "[link=" not in value: - value = escape(value) - values.append(value) - rich_table.add_row(*values) - if len(rich_table.columns) > 1: - rich_table.columns[0].justify = "center" - console.print(rich_table) def print_pydantic_models( @@ -346,7 +333,7 @@ def print_pydantic_models( exclude_columns: Optional[List[str]] = None, active_models: Optional[List[T]] = None, show_active: bool = False, - rename_columns: Dict[str, str] = {}, + column_aliases: Optional[Dict[str, str]] = None, ) -> None: """Prints the list of Pydantic models in a table. @@ -359,7 +346,8 @@ def print_pydantic_models( active_models: Optional list of active models of the given type T. show_active: Flag to decide whether to append the active model on the top of the list. - rename_columns: Optional dictionary to rename columns. + column_aliases: Optional mapping of original column names to display + names. Use this to rename columns in the table output. """ if exclude_columns is None: exclude_columns = list() @@ -378,74 +366,16 @@ def __dictify(model: T) -> Dict[str, str]: Returns: Dict of model attributes. """ - # Explicitly defined columns take precedence over exclude columns - if not columns: - if isinstance(model, BaseIdentifiedResponse): - include_columns = ["id"] - - if "name" in type(model).model_fields: - include_columns.append("name") - - include_columns.extend( - [ - k - for k in type(model.get_body()).model_fields.keys() - if k not in exclude_columns - ] - ) - - if model.metadata is not None: - include_columns.extend( - [ - k - for k in type( - model.get_metadata() - ).model_fields.keys() - if k not in exclude_columns - ] - ) - - else: - include_columns = [ - k - for k in model.model_dump().keys() - if k not in exclude_columns - ] - else: - include_columns = columns + include_columns = _extract_model_columns( + model=model, columns=columns, exclude_columns=exclude_columns + ) items: Dict[str, Any] = {} for k in include_columns: value = getattr(model, k) - if k in rename_columns: - k = rename_columns[k] - # In case the response model contains nested `BaseResponse`s - # we want to attempt to represent them by name, if they contain - # such a field, else the id is used - if isinstance(value, BaseIdentifiedResponse): - if "name" in type(value).model_fields: - items[k] = str(getattr(value, "name")) - else: - items[k] = str(value.id) - - # If it is a list of `BaseResponse`s access each Model within - # the list and extract either name or id - elif isinstance(value, list): - for v in value: - if isinstance(v, BaseIdentifiedResponse): - if "name" in type(v).model_fields: - items.setdefault(k, []).append( - str(getattr(v, "name")) - ) - else: - items.setdefault(k, []).append(str(v.id)) - elif isinstance(value, Set) or isinstance(value, List): - items[k] = [str(v) for v in value] - else: - items[k] = str(value) + items[k] = _format_response_value(value) - # prepend an active marker if a function to mark active was passed if not active_models and not show_active: return items @@ -455,9 +385,9 @@ def __dictify(model: T) -> Dict[str, str]: if active_models is not None and show_active_column: return { marker: ( - ":point_right:" + "[green]●[/green]" if any(model.id == a.id for a in active_models) - else "" + else "-" ), **items, } @@ -477,7 +407,10 @@ def __dictify(model: T) -> Dict[str, str]: i for i in table_items if i.id not in active_ids ] - print_table([__dictify(model) for model in table_items]) + print_table( + [__dictify(model) for model in table_items], + column_aliases=column_aliases, + ) print_page_info(models) else: table_items = list(models) @@ -491,7 +424,10 @@ def __dictify(model: T) -> Dict[str, str]: i for i in table_items if i.id not in active_ids ] - print_table([__dictify(model) for model in table_items]) + print_table( + [__dictify(model) for model in table_items], + column_aliases=column_aliases, + ) def print_pydantic_model( @@ -508,83 +444,18 @@ def print_pydantic_model( exclude_columns: Optionally specify columns to exclude. columns: Optionally specify subset and order of columns to display. """ - rich_table = table.Table( - box=box.ROUNDED, - title=title, - show_lines=True, - border_style="dim", + include_columns = _extract_model_columns( + model=model, + columns=list(columns) if columns else None, + exclude_columns=list(exclude_columns) if exclude_columns else None, ) - rich_table.add_column("PROPERTY", overflow="fold") - rich_table.add_column("VALUE", overflow="fold") - - # TODO: This uses the same _dictify function up in the print_pydantic_models - # function. This 2 can be generalized. - if exclude_columns is None: - exclude_columns = set() - - if not columns: - if isinstance(model, BaseIdentifiedResponse): - include_columns = ["id"] - - if "name" in type(model).model_fields: - include_columns.append("name") - - include_columns.extend( - [ - k - for k in type(model.get_body()).model_fields.keys() - if k not in exclude_columns - ] - ) - - if model.metadata is not None: - include_columns.extend( - [ - k - for k in type(model.get_metadata()).model_fields.keys() - if k not in exclude_columns - ] - ) - - else: - include_columns = [ - k - for k in model.model_dump().keys() - if k not in exclude_columns - ] - else: - include_columns = list(columns) items: Dict[str, Any] = {} - for k in include_columns: value = getattr(model, k) - if isinstance(value, BaseIdentifiedResponse): - if "name" in type(value).model_fields: - items[k] = str(getattr(value, "name")) - else: - items[k] = str(value.id) - - # If it is a list of `BaseResponse`s access each Model within - # the list and extract either name or id - elif isinstance(value, list): - for v in value: - if isinstance(v, BaseIdentifiedResponse): - if "name" in type(v).model_fields: - items.setdefault(k, []).append(str(getattr(v, "name"))) - else: - items.setdefault(k, []).append(str(v.id)) - - items[k] = str(items[k]) - elif isinstance(value, Set) or isinstance(value, List): - items[k] = str([str(v) for v in value]) - else: - items[k] = str(value) - - for k, v in items.items(): - rich_table.add_row(str(k).upper(), v) + items[k] = _format_response_value(value) - console.print(rich_table) + _print_key_value_table(items, title=title) def format_integration_list( @@ -632,14 +503,12 @@ def print_stack_configuration(stack: "StackResponse", active: bool) -> None: title="Stack Configuration", caption=stack_caption, show_lines=True, - border_style="dim", ) rich_table.add_column("COMPONENT_TYPE", overflow="fold") rich_table.add_column("COMPONENT_NAME", overflow="fold") for component_type, components in stack.components.items(): rich_table.add_row(component_type, components[0].name) - # capitalize entries in first column rich_table.columns[0]._cells = [ component.upper() # type: ignore[union-attr] for component in rich_table.columns[0]._cells @@ -653,7 +522,6 @@ def print_stack_configuration(stack: "StackResponse", active: bool) -> None: box=box.ROUNDED, title="Labels", show_lines=True, - border_style="dim", ) rich_table.add_column("LABEL") rich_table.add_column("VALUE", overflow="fold") @@ -704,10 +572,7 @@ def print_stack_component_configuration( from the component flavor. Only needed if the component has a connector. """ - if component.user: - user_name = component.user.name - else: - user_name = "-" + user_name = _get_user_name(component.user) declare( f"{component.type.value.title()} '{component.name}' of flavor " @@ -730,7 +595,6 @@ def print_stack_component_configuration( box=box.ROUNDED, title=title_, show_lines=True, - border_style="dim", ) rich_table.add_column("COMPONENT_PROPERTY") rich_table.add_column("VALUE", overflow="fold") @@ -754,7 +618,6 @@ def print_stack_component_configuration( box=box.ROUNDED, title="Labels", show_lines=True, - border_style="dim", ) rich_table.add_column("LABEL") rich_table.add_column("VALUE", overflow="fold") @@ -771,7 +634,6 @@ def print_stack_component_configuration( box=box.ROUNDED, title="Service Connector", show_lines=True, - border_style="dim", ) rich_table.add_column("PROPERTY") rich_table.add_column("VALUE", overflow="fold") @@ -968,27 +830,6 @@ def validate_keys(key: str) -> None: error("Please provide args with a proper identifier as the key.") -def parse_unknown_component_attributes(args: List[str]) -> List[str]: - """Parse unknown options from the CLI. - - Args: - args: A list of strings from the CLI. - - Returns: - List of parsed args. - """ - warning_message = ( - "Please provide args with a proper " - "identifier as the key and the following structure: " - "--custom_attribute" - ) - - assert all(a.startswith("--") for a in args), warning_message - p_args = [a.lstrip("-") for a in args] - assert all(v.isidentifier() for v in p_args), warning_message - return p_args - - def prompt_configuration( config_schema: Dict[str, Any], show_secrets: bool = False, @@ -1176,7 +1017,6 @@ def uninstall_package(package: str, use_uv: bool = False) -> None: use_uv: Whether to use uv for package uninstallation. """ if use_uv and not requirement_installed("uv"): - # If uv is installed globally, don't run as a python module command = [] else: command = [sys.executable, "-m"] @@ -1238,26 +1078,6 @@ def get_secret_value(value: Any) -> str: print_table(stack_dicts, title=title) -def print_list_items(list_items: List[str], column_title: str) -> None: - """Prints the configuration options of a stack. - - Args: - list_items: List of items - column_title: Title of the column - """ - rich_table = table.Table( - box=box.ROUNDED, - show_lines=True, - border_style="dim", - ) - rich_table.add_column(column_title.upper(), overflow="fold") - list_items.sort() - for item in list_items: - rich_table.add_row(item) - - console.print(rich_table) - - def get_service_state_emoji(state: "ServiceState") -> str: """Get the rich emoji representing the operational state of a Service. @@ -1373,10 +1193,9 @@ def pretty_print_model_version_details( title_ = f"Properties of model `{model_version.registered_model.name}` version `{model_version.version}`" rich_table = table.Table( - box=box.ROUNDED, + box=None, title=title_, - show_lines=True, - border_style="dim", + show_lines=False, ) rich_table.add_column("MODEL VERSION PROPERTY", overflow="fold") rich_table.add_column("VALUE", overflow="fold") @@ -1406,7 +1225,6 @@ def pretty_print_model_version_details( for item in model_version_info.items(): rich_table.add_row(*[str(elem) for elem in item]) - # capitalize entries in first column rich_table.columns[0]._cells = [ component.upper() # type: ignore[union-attr] for component in rich_table.columns[0]._cells @@ -1429,7 +1247,6 @@ def print_served_model_configuration( box=box.ROUNDED, title=title_, show_lines=True, - border_style="dim", ) rich_table.add_column("MODEL SERVICE PROPERTY", overflow="fold") rich_table.add_column("VALUE", overflow="fold") @@ -1454,7 +1271,6 @@ def print_served_model_configuration( for item in sorted_items.items(): rich_table.add_row(*[str(elem) for elem in item]) - # capitalize entries in first column rich_table.columns[0]._cells = [ component.upper() # type: ignore[union-attr] for component in rich_table.columns[0]._cells @@ -1469,14 +1285,11 @@ def describe_pydantic_object(schema_json: Dict[str, Any]) -> None: schema_json: str, represents the schema of a Pydantic object, which can be obtained through BaseModelClass.schema_json() """ - # Get the schema dict - # Extract values with defaults schema_title = schema_json["title"] required = schema_json.get("required", []) description = schema_json.get("description", "") properties = schema_json.get("properties", {}) - # Pretty print the schema warning(f"Configuration class: {schema_title}\n", bold=True) if description: @@ -1525,196 +1338,15 @@ def replace_emojis(text: str) -> str: Returns: Text with expanded emojis. """ - emoji_pattern = r":(\w+):" - emojis = re.findall(emoji_pattern, text) - for emoji in emojis: + + def _replace_emoji(match: re.Match[str]) -> str: + emoji_code = match.group(1) try: - text = text.replace(f":{emoji}:", str(Emoji(emoji))) + return str(Emoji(emoji_code)) except NoEmoji: - # If the emoji text is not a valid emoji, just ignore it - pass - return text - - -def print_stacks_table( - client: "Client", - stacks: Sequence["StackResponse"], - show_active: bool = False, -) -> None: - """Print a prettified list of all stacks supplied to this method. - - Args: - client: Repository instance - stacks: List of stacks - show_active: Flag to decide whether to append the active stack on the - top of the list. - """ - stack_dicts = [] - - stacks = list(stacks) - active_stack = client.active_stack_model - if show_active: - if active_stack.id not in [s.id for s in stacks]: - stacks.append(active_stack) - - stacks = [s for s in stacks if s.id == active_stack.id] + [ - s for s in stacks if s.id != active_stack.id - ] - - active_stack_model_id = client.active_stack_model.id - for stack in stacks: - is_active = stack.id == active_stack_model_id - - if stack.user: - user_name = stack.user.name - else: - user_name = "-" - - stack_config = { - "ACTIVE": ":point_right:" if is_active else "", - "STACK NAME": stack.name, - "STACK ID": stack.id, - "OWNER": user_name, - **{ - component_type.upper(): components[0].name - for component_type, components in stack.components.items() - }, - } - stack_dicts.append(stack_config) - - print_table(stack_dicts) - - -def print_components_table( - client: "Client", - component_type: StackComponentType, - components: Sequence["ComponentResponse"], - show_active: bool = False, -) -> None: - """Prints a table with configuration options for a list of stack components. - - If a component is active (its name matches the `active_component_name`), - it will be highlighted in a separate table column. - - Args: - client: Instance of the Repository singleton - component_type: Type of stack component - components: List of stack components to print. - show_active: Flag to decide whether to append the active stack component - on the top of the list. - """ - display_name = _component_display_name(component_type, plural=True) - - if len(components) == 0: - warning(f"No {display_name} registered.") - return - - active_stack = client.active_stack_model - active_component = None - if component_type in active_stack.components.keys(): - active_components = active_stack.components[component_type] - active_component = active_components[0] if active_components else None - - components = list(components) - if show_active and active_component is not None: - if active_component.id not in [c.id for c in components]: - components.append(active_component) - - components = [c for c in components if c.id == active_component.id] + [ - c for c in components if c.id != active_component.id - ] - - configurations = [] - for component in components: - is_active = False - - if active_component is not None: - is_active = component.id == active_component.id - - component_config = { - "ACTIVE": ":point_right:" if is_active else "", - "NAME": component.name, - "COMPONENT ID": component.id, - "FLAVOR": component.flavor_name, - "OWNER": f"{component.user.name if component.user else '-'}", - } - configurations.append(component_config) - print_table(configurations) - - -def print_service_connectors_table( - client: "Client", - connectors: Sequence["ServiceConnectorResponse"], - show_active: bool = False, -) -> None: - """Prints a table with details for a list of service connectors. - - Args: - client: Instance of the Repository singleton - connectors: List of service connectors to print. - show_active: lag to decide whether to append the active connectors - on the top of the list. - """ - if len(connectors) == 0: - return + return match.group(0) - active_connectors: List["ServiceConnectorResponse"] = [] - for components in client.active_stack_model.components.values(): - for component in components: - if component.connector: - connector = component.connector - if connector.id not in [c.id for c in active_connectors]: - if isinstance(connector.connector_type, str): - # The connector embedded within the stack component - # does not include a hydrated connector type. We need - # that to print its emojis. - connector.set_connector_type( - client.get_service_connector_type( - connector.connector_type - ) - ) - active_connectors.append(connector) - - connectors = list(connectors) - if show_active: - active_ids = [c.id for c in connectors] - for active_connector in active_connectors: - if active_connector.id not in active_ids: - connectors.append(active_connector) - - connectors = [c for c in connectors if c.id in active_ids] + [ - c for c in connectors if c.id not in active_ids - ] - - configurations = [] - for connector in connectors: - is_active = connector.id in [c.id for c in active_connectors] - labels = [ - f"{label}:{value}" for label, value in connector.labels.items() - ] - resource_name = connector.resource_id or "" - - connector_config = { - "ACTIVE": ":point_right:" if is_active else "", - "NAME": connector.name, - "ID": connector.id, - "TYPE": connector.emojified_connector_type, - "RESOURCE TYPES": "\n".join(connector.emojified_resource_types), - "RESOURCE NAME": resource_name, - "OWNER": f"{connector.user.name if connector.user else '-'}", - "EXPIRES IN": ( - expires_in( - connector.expires_at, - ":name_badge: Expired!", - connector.expires_skew_tolerance, - ) - if connector.expires_at - else "" - ), - "LABELS": "\n".join(labels), - } - configurations.append(connector_config) - print_table(configurations) + return re.sub(r":(\w+):", _replace_emoji, text) def print_service_connector_resource_table( @@ -1727,13 +1359,32 @@ def print_service_connector_resource_table( resources: List of service connector resources to print. show_resources_only: If True, only the resources will be printed. """ + + def _truncate_error(error_msg: str, max_length: int = 100) -> str: + """Truncate long error messages for better readability. + + Args: + error_msg: The error message to truncate. + max_length: The maximum length of the error message. + + Returns: + The truncated error message. + """ + if len(error_msg) <= max_length: + return error_msg + truncated = error_msg[:max_length].rsplit(" ", 1)[0] + return f"{truncated}... (use --verbose for full error)" + resource_table = [] - for resource_model in resources: + errors_found = False + + for i, resource_model in enumerate(resources): printed_connector = False resource_row: Dict[str, Any] = {} if resource_model.error: # Global error + errors_found = True if not show_resources_only: resource_row = { "CONNECTOR ID": str(resource_model.id), @@ -1745,10 +1396,28 @@ def print_service_connector_resource_table( "RESOURCE TYPE": "\n".join( resource_model.get_emojified_resource_types() ), - "RESOURCE NAMES": f":collision: error: {resource_model.error}", + "RESOURCE NAMES": f":collision: {_truncate_error(resource_model.error)}", } ) resource_table.append(resource_row) + if i < len(resources) - 1: + if not show_resources_only: + resource_table.append( + { + "CONNECTOR ID": "", + "CONNECTOR NAME": "", + "CONNECTOR TYPE": "", + "RESOURCE TYPE": "", + "RESOURCE NAMES": "", + } + ) + else: + resource_table.append( + { + "RESOURCE TYPE": "", + "RESOURCE NAMES": "", + } + ) continue for resource in resource_model.resources: @@ -1756,8 +1425,10 @@ def print_service_connector_resource_table( resource.resource_type )[0] if resource.error: - # Error fetching resources - resource_ids = [f":collision: error: {resource.error}"] + errors_found = True + resource_ids = [ + f":collision: {_truncate_error(resource.error)}" + ] elif resource.resource_ids: resource_ids = resource.resource_ids else: @@ -1786,7 +1457,30 @@ def print_service_connector_resource_table( ) resource_table.append(resource_row) printed_connector = True + if i < len(resources) - 1 and resource_model.resources: + if not show_resources_only: + resource_table.append( + { + "CONNECTOR ID": "", + "CONNECTOR NAME": "", + "CONNECTOR TYPE": "", + "RESOURCE TYPE": "", + "RESOURCE NAMES": "", + } + ) + else: + resource_table.append( + { + "RESOURCE TYPE": "", + "RESOURCE NAMES": "", + } + ) print_table(resource_table) + if errors_found: + console.print( + "\n[dim]💡 Tip: Some error messages were truncated for readability. " + "Use describe commands for full error details.[/dim]" + ) def print_service_connector_configuration( @@ -1801,17 +1495,9 @@ def print_service_connector_configuration( active_status: Whether the connector is active. show_secrets: Whether to show secrets. """ - from uuid import UUID - from zenml.models import ServiceConnectorResponse - if connector.user: - if isinstance(connector.user, UUID): - user_name = str(connector.user) - else: - user_name = connector.user.name - else: - user_name = "-" + user_name = _get_user_name(connector.user) if isinstance(connector, ServiceConnectorResponse): declare( @@ -1832,7 +1518,6 @@ def print_service_connector_configuration( box=box.ROUNDED, title=title_, show_lines=True, - border_style="dim", ) rich_table.add_column("PROPERTY") rich_table.add_column("VALUE", overflow="fold") @@ -1907,7 +1592,6 @@ def print_service_connector_configuration( box=box.ROUNDED, title="Configuration", show_lines=True, - border_style="dim", ) rich_table.add_column("PROPERTY") rich_table.add_column("VALUE", overflow="fold") @@ -1936,7 +1620,6 @@ def print_service_connector_configuration( box=box.ROUNDED, title="Labels", show_lines=True, - border_style="dim", ) rich_table.add_column("LABEL") rich_table.add_column("VALUE", overflow="fold") @@ -2187,20 +1870,6 @@ def print_service_connector_type( return message -def _get_stack_components( - stack: "Stack", -) -> "List[StackComponent]": - """Get a dict of all components in a stack. - - Args: - stack: A stack - - Returns: - A list of all components in a stack. - """ - return list(stack.components.values()) - - def _scrub_secret(config: StackComponentConfig) -> Dict[str, Any]: """Remove secret values from a configuration. @@ -2277,6 +1946,24 @@ def _component_display_name( return name.replace("_", " ") +def _active_status( + is_active: bool, output_format: OutputFormat +) -> Union[str, bool]: + """Format active status based on output format. + + Args: + is_active: Whether the item is active. + output_format: The output format. + + Returns: + For table format: green dot if active, empty string if not. + For other formats: boolean value. + """ + if output_format == "table": + return "[green]●[/green]" if is_active else "" + return is_active + + def get_execution_status_emoji(status: "ExecutionStatus") -> str: """Returns an emoji representing the given execution status. @@ -2306,43 +1993,6 @@ def get_execution_status_emoji(status: "ExecutionStatus") -> str: raise RuntimeError(f"Unknown status: {status}") -def print_pipeline_runs_table( - pipeline_runs: Sequence["PipelineRunResponse"], -) -> None: - """Print a prettified list of all pipeline runs supplied to this method. - - Args: - pipeline_runs: List of pipeline runs - """ - runs_dicts = [] - for pipeline_run in pipeline_runs: - if pipeline_run.user: - user_name = pipeline_run.user.name - else: - user_name = "-" - - if pipeline_run.pipeline is None: - pipeline_name = "unlisted" - else: - pipeline_name = pipeline_run.pipeline.name - if pipeline_run.stack is None: - stack_name = "[DELETED]" - else: - stack_name = pipeline_run.stack.name - status = pipeline_run.status - status_emoji = get_execution_status_emoji(status) - run_dict = { - "PIPELINE NAME": pipeline_name, - "RUN NAME": pipeline_run.name, - "RUN ID": pipeline_run.id, - "STATUS": status_emoji, - "STACK": stack_name, - "OWNER": user_name, - } - runs_dicts.append(run_dict) - print_table(runs_dicts) - - def fetch_snapshot( snapshot_name_or_id: str, pipeline_name_or_id: Optional[str] = None, @@ -2415,103 +2065,299 @@ def get_deployment_status_emoji( return ":question:" -def format_deployment_status(status: Optional[str]) -> str: - """Format deployment status with color. +def _get_extra_columns_for_filter(filter_name: str) -> List[str]: + """Get extra columns added by row formatter functions. + + These columns are added by generate_*_row functions and aren't part + of the Response model itself. Args: - status: The deployment status. + filter_name: Name of the filter class (e.g., "StackFilter"). Returns: - Formatted status string. - """ - if status == DeploymentStatus.RUNNING: - return "[green]RUNNING[/green]" - elif status == DeploymentStatus.PENDING: - return "[yellow]PENDING[/yellow]" - elif status == DeploymentStatus.ERROR: - return "[red]ERROR[/red]" - elif status == DeploymentStatus.ABSENT: - return "[dim]ABSENT[/dim]" - - return "[dim]UNKNOWN[/dim]" + List of extra column names added by the row formatter. + """ + extra_columns_map: Dict[str, List[str]] = { + "StackFilter": [ct.value.lower() for ct in StackComponentType], + "ComponentFilter": ["flavor", "owner"], + "DeploymentFilter": [ + "pipeline", + "snapshot", + "url", + "status", + "stack", + "owner", + ], + "PipelineRunFilter": [ + "pipeline", + "run_name", + "status", + "stack", + "owner", + ], + "ServiceConnectorFilter": [ + "connector_type", + "resource_types", + "auth_method", + ], + } + return extra_columns_map.get(filter_name, []) -def print_deployment_table( - deployments: Sequence["DeploymentResponse"], -) -> None: - """Print a prettified list of all deployments supplied to this method. +def generate_deployment_row( + deployment: "DeploymentResponse", output_format: OutputFormat +) -> Dict[str, Any]: + """Generate additional data for deployment display. Args: - deployments: List of deployments - """ - deployment_dicts = [] - for deployment in deployments: - if deployment.user: - user_name = deployment.user.name - else: - user_name = "-" - - if deployment.snapshot is None or deployment.snapshot.pipeline is None: - pipeline_name = "unlisted" - else: - pipeline_name = deployment.snapshot.pipeline.name - if deployment.snapshot is None or deployment.snapshot.stack is None: - stack_name = "[DELETED]" - else: - stack_name = deployment.snapshot.stack.name - status = deployment.status or DeploymentStatus.UNKNOWN.value - status_emoji = get_deployment_status_emoji(status) - run_dict = { - "ID": deployment.id, - "NAME": deployment.name, - "PIPELINE": pipeline_name, - "SNAPSHOT": deployment.snapshot.name or "" - if deployment.snapshot - else "N/A", - "URL": deployment.url or "N/A", - "STATUS": f"{status_emoji} {status.upper()}", - "STACK": stack_name, - "OWNER": user_name, - } - deployment_dicts.append(run_dict) - print_table(deployment_dicts) + deployment: The deployment response. + output_format: The output format. - -def pretty_print_deployment( - deployment: "DeploymentResponse", - show_secret: bool = False, - show_metadata: bool = False, - show_schema: bool = False, - no_truncate: bool = False, -) -> None: - """Print a prettified deployment with organized sections. - - Args: - deployment: The deployment to print. - show_secret: Whether to show the auth key or mask it. - show_metadata: Whether to show the metadata. - show_schema: Whether to show the schema. - no_truncate: Whether to truncate the metadata. + Returns: + The additional data for the deployment. """ - # Header section - status_label = (deployment.status or "UNKNOWN").upper() - status_emoji = get_deployment_status_emoji(deployment.status) - declare( - f"\n[bold]Deployment:[/bold] [bold cyan]{deployment.name}[/bold cyan] status: {status_label} {status_emoji}" - ) - if deployment.snapshot is None: - pipeline_name = "N/A" - snapshot_name = "N/A" + user_name = _get_user_name(deployment.user) + + if deployment.snapshot is None or deployment.snapshot.pipeline is None: + pipeline_name = "unlisted" else: pipeline_name = deployment.snapshot.pipeline.name - snapshot_name = deployment.snapshot.name or str(deployment.snapshot.id) + if deployment.snapshot is None or deployment.snapshot.stack is None: stack_name = "[DELETED]" else: stack_name = deployment.snapshot.stack.name - declare(f"\n[bold]Pipeline:[/bold] [bold cyan]{pipeline_name}[/bold cyan]") - declare(f"[bold]Snapshot:[/bold] [bold cyan]{snapshot_name}[/bold cyan]") - declare(f"[bold]Stack:[/bold] [bold cyan]{stack_name}[/bold cyan]") + + status = deployment.status or DeploymentStatus.UNKNOWN.value + + if output_format == "table": + status_emoji = get_deployment_status_emoji(status) + status_display = f"{status_emoji} {status.upper()}" + else: + status_display = status.upper() + + return { + "pipeline": pipeline_name, + "snapshot": deployment.snapshot.name or "" + if deployment.snapshot + else "N/A", + "url": deployment.url or "N/A", + "status": status_display, + "stack": stack_name, + "owner": user_name, + } + + +def generate_stack_row( + stack: "StackResponse", + output_format: OutputFormat, + active_id: Optional["UUID"] = None, +) -> Dict[str, Any]: + """Generate row data for stack display. + + Args: + stack: The stack response. + output_format: The output format. + active_id: ID of the active stack for highlighting. + + Returns: + Dict with stack data for display. + """ + is_active = active_id is not None and stack.id == active_id + + row: Dict[str, Any] = { + "active": _active_status(is_active, output_format), + } + + for component_type in StackComponentType: + components = stack.components.get(component_type) + if output_format == "table": + header = component_type.value.upper().replace("_", " ") + else: + header = component_type.value + row[header] = components[0].name if components else "-" + + return row + + +def generate_project_row( + project: "ProjectResponse", + output_format: OutputFormat, + active_id: Optional["UUID"] = None, +) -> Dict[str, Any]: + """Generate row data for project display. + + Args: + project: The project response. + output_format: The output format. + active_id: ID of the active project for highlighting. + + Returns: + Dict with project data for display. + """ + is_active = active_id is not None and project.id == active_id + + return { + "active": _active_status(is_active, output_format), + } + + +def generate_user_row( + user: "UserResponse", + output_format: OutputFormat, + active_id: Optional["UUID"] = None, +) -> Dict[str, Any]: + """Generate row data for user display. + + Args: + user: The user response. + output_format: The output format. + active_id: ID of the active user for highlighting. + + Returns: + Dict with user data for display. + """ + is_active = active_id is not None and user.id == active_id + + return { + "active": _active_status(is_active, output_format), + } + + +def generate_pipeline_run_row( + pipeline_run: "PipelineRunResponse", + output_format: OutputFormat, +) -> Dict[str, Any]: + """Generate row data for pipeline run display. + + Args: + pipeline_run: The pipeline run response. + output_format: The output format. + + Returns: + Dict with pipeline run data for display. + """ + pipeline_name = ( + pipeline_run.pipeline.name if pipeline_run.pipeline else "unlisted" + ) + stack_name = pipeline_run.stack.name if pipeline_run.stack else "[DELETED]" + user_name = _get_user_name(pipeline_run.user) + status = pipeline_run.status + status_emoji = get_execution_status_emoji(status) + + return { + "pipeline": pipeline_name, + "run_name": pipeline_run.name, + "status": status_emoji if output_format == "table" else str(status), + "stack": stack_name, + "owner": user_name, + } + + +def generate_component_row( + component: "ComponentResponse", + output_format: OutputFormat, + active_id: Optional["UUID"] = None, +) -> Dict[str, Any]: + """Generate row data for component display. + + Args: + component: The component response. + output_format: The output format. + active_id: ID of the active component for highlighting. + + Returns: + Dict with component data for display. + """ + is_active = active_id is not None and component.id == active_id + + return { + "active": _active_status(is_active, output_format), + "name": component.name, + "component_id": component.id, + "flavor": component.flavor_name, + "owner": _get_user_name(component.user), + } + + +def generate_connector_row( + connector: "ServiceConnectorResponse", + output_format: OutputFormat, + active_connector_ids: Optional[List["UUID"]] = None, +) -> Dict[str, Any]: + """Generate row data for service connector display. + + Args: + connector: The service connector response. + output_format: The output format. + active_connector_ids: List of active connector IDs for highlighting. + + Returns: + Dict with connector data for display. + """ + is_active = bool( + active_connector_ids and connector.id in active_connector_ids + ) + labels = [f"{label}:{value}" for label, value in connector.labels.items()] + resource_name = connector.resource_id or "" + resource_types_str = "\n".join(connector.emojified_resource_types) + + return { + "active": _active_status(is_active, output_format), + "name": connector.name, + "id": connector.id, + "type": connector.emojified_connector_type, + "resource_types": resource_types_str, + "resource_name": resource_name, + "owner": _get_user_name(connector.user), + "expires_in": ( + expires_in( + connector.expires_at, + ":name_badge: Expired!", + connector.expires_skew_tolerance, + ) + if connector.expires_at + else "" + ), + "labels": "\n".join(labels), + } + + +def pretty_print_deployment( + deployment: "DeploymentResponse", + show_secret: bool = False, + show_metadata: bool = False, + show_schema: bool = False, + no_truncate: bool = False, +) -> None: + """Print a prettified deployment with organized sections. + + Args: + deployment: The deployment to print. + show_secret: Whether to show the auth key or mask it. + show_metadata: Whether to show the metadata. + show_schema: Whether to show the schema. + no_truncate: Whether to truncate the metadata. + """ + # Header section + status_label = (deployment.status or "UNKNOWN").upper() + status_emoji = get_deployment_status_emoji(deployment.status) + declare( + f"\n[bold]Deployment:[/bold] [bold cyan]{deployment.name}[/bold cyan] status: {status_label} {status_emoji}" + ) + if deployment.snapshot is None: + pipeline_name = "N/A" + snapshot_name = "N/A" + else: + pipeline_name = deployment.snapshot.pipeline.name + snapshot_name = deployment.snapshot.name or str(deployment.snapshot.id) + if deployment.snapshot is None or deployment.snapshot.stack is None: + stack_name = "[DELETED]" + else: + stack_name = deployment.snapshot.stack.name + declare(f"\n[bold]Pipeline:[/bold] [bold cyan]{pipeline_name}[/bold cyan]") + declare(f"[bold]Snapshot:[/bold] [bold cyan]{snapshot_name}[/bold cyan]") + declare(f"[bold]Stack:[/bold] [bold cyan]{stack_name}[/bold cyan]") # Connection section if deployment.url: @@ -2670,15 +2516,15 @@ def check_zenml_pro_project_availability() -> None: ) -def print_page_info(page: Page[T]) -> None: +def print_page_info(page: "Page[Any]") -> None: """Print all page information showing the number of items and pages. Args: - page: The page to print the information for. + page: The page object containing pagination information. """ declare( - f"Page `({page.index}/{page.total_pages})`, `{page.total}` items " - f"found for the applied filters." + f"Page `({page.index}/{page.total_pages})`, " + f"`{page.total}` items found for the applied filters." ) @@ -2795,21 +2641,92 @@ def _is_list_field(field_info: Any) -> bool: ) -def list_options(filter_model: Type[BaseFilter]) -> Callable[[F], F]: - """Create a decorator to generate the correct list of filter parameters. +def _get_response_columns_for_filter( + filter_model: Type[BaseFilter], +) -> List[str]: + """Get available column names by introspecting the Response model. + + Derives the Response model name from the Filter model name and extracts + field names from its body, metadata, and resources classes. Also includes + extra columns added by generate_*_row functions. + + Args: + filter_model: The filter model class. + + Returns: + List of available column names, or empty list if derivation fails. + """ + from typing import get_args as typing_get_args + from typing import get_origin as typing_get_origin + + import zenml.models as models_module + + filter_name = filter_model.__name__ + if not filter_name.endswith("Filter"): + return [] + + response_name = filter_name.replace("Filter", "Response") + response_class = getattr(models_module, response_name, None) + if response_class is None: + return [] + + columns: Set[str] = {"id"} + + if "name" in response_class.model_fields: + columns.add("name") + + for attr_name in ["body", "metadata", "resources"]: + field_info = response_class.model_fields.get(attr_name) + if field_info is None: + continue + annotation = field_info.annotation + origin = typing_get_origin(annotation) + if origin is not None: + args = typing_get_args(annotation) + for arg in args: + if arg is not type(None) and hasattr(arg, "model_fields"): + for field_name in arg.model_fields: + columns.add(field_name.lower().replace(" ", "_")) + break + elif hasattr(annotation, "model_fields"): + for field_name in annotation.model_fields: + columns.add(field_name.lower().replace(" ", "_")) + + extra_columns = _get_extra_columns_for_filter(filter_name) + columns.update(extra_columns) - The Outer decorator (`list_options`) is responsible for creating the inner - decorator. This is necessary so that the type of `FilterModel` can be passed - in as a parameter. + return sorted(columns) - Based on the filter model, the inner decorator extracts all the click - options that should be added to the decorated function (wrapper). + +def list_options( + filter_model: Type[BaseFilter], + default_columns: Optional[List[str]] = None, +) -> Callable[[F], F]: + """Add filter and output options to a list command. + + This decorator generates click options from a FilterModel and adds standard + output formatting options (--columns, --output). The decorated function + receives these as regular parameters - no magic interception! + + The function should call print_page() to render results. Args: - filter_model: The filter model based on which to decorate the function. + filter_model: The filter model to generate filter options from. + default_columns: Optional list of column names to use as defaults. Returns: - The inner decorator. + The decorator function. + + Example: + ```python + @list_options(StackFilter, default_columns=["id", "name"]) + def list_stacks(columns: str, output_format: str, **kwargs: Any) -> None: + stacks = Client().list_stacks(**kwargs) + if not stacks.items: + declare("No stacks found") + return + print_page(stacks, columns, output_format) + ``` """ def inner_decorator(func: F) -> F: @@ -2832,6 +2749,43 @@ def inner_decorator(func: F) -> F: create_data_type_help_text(filter_model, k) ) + default_columns_list = default_columns or [] + default_columns_str = ",".join(default_columns_list) + + derived_columns = _get_response_columns_for_filter(filter_model) + all_columns = sorted(set(derived_columns) | set(default_columns_list)) + columns_help = ( + "Comma-separated list of columns to display, or 'all' " + "for all columns." + ) + if all_columns: + columns_help += f" Available: {', '.join(all_columns)}." + + options.extend( + [ + click.option( + "--columns", + "-c", + type=str, + default=default_columns_str, + help=columns_help, + ), + click.option( + "--output", + "-o", + "output_format", + type=click.Choice(["table", "json", "yaml", "tsv", "csv"]), + default=get_default_output_format(), + help="Output format for the list.", + ), + ] + ) + + def wrapper(function: F) -> F: + for option in reversed(options): + function = option(function) + return function + func.__doc__ = ( f"{func.__doc__} By default all filters are " f"interpreted as a check for equality. However advanced " @@ -2852,17 +2806,7 @@ def inner_decorator(func: F) -> F: f"{joined_data_type_descriptors}" ) - for option in reversed(options): - func = option(func) - - @functools.wraps(func) - def wrapper(*args: Any, **kwargs: Any) -> Any: - nonlocal func - - kwargs = dict_utils.remove_none_values(kwargs) - return func(*args, **kwargs) - - return cast(F, wrapper) + return wrapper(func) return inner_decorator @@ -2894,19 +2838,6 @@ def temporary_active_stack( Client().activate_stack(old_stack_id) -def print_user_info(info: Dict[str, Any]) -> None: - """Print user information to the terminal. - - Args: - info: The information to print. - """ - for key, value in info.items(): - if key in ["packages", "query_packages"] and not bool(value): - continue - - declare(f"{key.upper()}: {value}") - - def get_parsed_labels( labels: Optional[List[str]], allow_label_only: bool = False ) -> Dict[str, Optional[str]]: @@ -2948,15 +2879,16 @@ def is_sorted_or_filtered(ctx: click.Context) -> bool: ctx: the Click context of the CLI call. Returns: - a boolean indicating whether any sorting or filtering parameters were - used during the list CLI call. + True if any parameter source differs from default, else False. """ + display_options = {"output_format", "columns"} try: - for _, source in ctx._parameter_source.items(): + for param, source in ctx._parameter_source.items(): + if param in display_options: + continue if source != click.core.ParameterSource.DEFAULT: return True return False - except Exception as e: logger.debug( f"There was a problem accessing the parameter source for " @@ -3018,9 +2950,8 @@ def multi_choice_prompt( table = Table( title=f"Available {object_type}", show_header=True, - border_style="dim", expand=True, - show_lines=True, + show_lines=False, ) table.add_column("Choice", justify="left", width=1) for h in headers: @@ -3050,8 +2981,7 @@ def multi_choice_prompt( if selected == "0" and allow_zero_be_a_new_object: return None - else: - return int(selected) - i_shift + return int(selected) - i_shift def requires_mac_env_var_warning() -> bool: @@ -3070,3 +3000,810 @@ def requires_mac_env_var_warning() -> bool: "OBJC_DISABLE_INITIALIZE_FORK_SAFETY" ) and mac_version_tuple >= (10, 13) return False + + +def get_default_output_format() -> OutputFormat: + """Get the default output format from environment variable. + + Returns: + The default output format, falling back to "table" if not configured + or if the configured value is invalid. + """ + value = os.environ.get(ENV_ZENML_DEFAULT_OUTPUT, "table") + valid_formats = get_args(OutputFormat) + if value in valid_formats: + return cast(OutputFormat, value) + return "table" + + +def prepare_response_data(item: AnyResponse) -> Dict[str, Any]: + """Prepare data from BaseResponse instances. + + This function extracts data from body, metadata, and resources of a + response model to create a flat dictionary suitable for CLI display. + It simplifies known nested objects (tags, components, user) to their + name representations. + + Args: + item: BaseResponse instance to format + + Returns: + Dictionary with the data + """ + + def _simplify_response(val: Any) -> Any: + """Simplify a value: Response -> name/id, list -> names, else as-is. + + Args: + val: Value to simplify + + Returns: + Simplified value + """ + if isinstance(val, BaseIdentifiedResponse): + return val.name if hasattr(val, "name") else str(val.id) + if isinstance(val, list) and val: + if isinstance(val[0], BaseIdentifiedResponse): + return [ + v.name if hasattr(v, "name") else str(v.id) for v in val + ] + if isinstance(val, dict) and val: + first_val = next(iter(val.values()), None) + if isinstance(first_val, list) and first_val: + if isinstance(first_val[0], BaseIdentifiedResponse): + return { + (k.value if hasattr(k, "value") else k): [ + v.name if hasattr(v, "name") else str(v.id) + for v in vs + ] + for k, vs in val.items() + } + return None + + def _process_model_fields(model: BaseModel) -> Dict[str, Any]: + """Extract fields, simplifying nested responses to names. + + Args: + model: Pydantic model to extract fields from + + Returns: + Dictionary with simplified field values + """ + result: Dict[str, Any] = {} + for field_name in type(model).model_fields: + val = getattr(model, field_name) + simplified = _simplify_response(val) + if simplified is not None: + result[field_name] = simplified + elif isinstance(val, BaseModel): + result[field_name] = val.model_dump(mode="json") + elif isinstance(val, UUID): + result[field_name] = str(val) + elif hasattr(val, "value"): # Enum + result[field_name] = val.value + else: + result[field_name] = val + return result + + item_data: Dict[str, Any] = {"id": str(item.id)} + + if "name" in type(item).model_fields: + item_data["name"] = getattr(item, "name") + + if item.body is not None: + item_data.update(_process_model_fields(item.body)) + + if item.metadata is not None: + item_data.update(_process_model_fields(item.metadata)) + + if item.resources is not None: + item_data.update(_process_model_fields(item.resources)) + + if isinstance(item, UserScopedResponse) and item.user: + item_data["user"] = item.user.name + + return item_data + + +def format_page_items( + page: Page[AnyResponse], + row_formatter: Optional[ + Callable[[Any, OutputFormat], Dict[str, Any]] + ] = None, + output_format: OutputFormat = "table", +) -> List[Dict[str, Any]]: + """Convert a Page of response models to a list of dicts for display. + + This is a lower-level helper that combines prepare_response_data with + optional custom formatting. For most use cases, prefer print_page() which + handles both formatting and output in a single call. + + Args: + page: Page of response items to convert. + row_formatter: Optional function to add custom fields to each row. + Should accept (item, output_format) and return a dict of additional fields. + output_format: Output format to pass to row_formatter (table/json/yaml/etc). + + Returns: + List of dicts ready to pass to handle_output. + + Example: + ```python + # Prefer print_page() for simple cases: + print_page(stacks_page, columns, output_format, generate_stack_row) + + # Use format_page_items() when you need to modify items before output: + items = format_page_items(stacks_page, generate_stack_row, output_format) + # ... modify items ... + handle_output(items, stacks_page, columns, output_format) + ``` + """ + result = [] + for item in page.items: + item_data = prepare_response_data(item) + if row_formatter: + additional_data = row_formatter(item, output_format) + if additional_data: + item_data.update(additional_data) + result.append(item_data) + return result + + +def print_page( + page: Page[AnyResponse], + columns: str, + output_format: OutputFormat, + row_formatter: Optional[ + Callable[[Any, OutputFormat], Dict[str, Any]] + ] = None, + column_aliases: Optional[Dict[str, str]] = None, + *, + empty_message: str = "No items found for this filter.", + row_generator: Optional[Callable[..., Dict[str, Any]]] = None, + active_id: Optional["UUID"] = None, +) -> None: + """Format and print a page of response items. + + This is a convenience function that combines format_page_items and + handle_output into a single call for cleaner CLI command implementations. + It also handles empty pages and active item highlighting automatically. + + Args: + page: Page of response items to display. + columns: Comma-separated column names. If empty, all columns are shown. + output_format: Output format (table, json, yaml, tsv, csv). + row_formatter: Optional function to add custom fields to each row. + Should accept (item, output_format) and return a dict of additional + fields. Use this for complex row formatting logic. + column_aliases: Optional mapping of original column names to display + names. Use this to rename columns in the table output. + empty_message: Message to display when the page has no items. + row_generator: Optional row generator function that accepts active_id. + When provided with active_id, creates a row_formatter automatically. + Use this for simple active item highlighting. + active_id: ID of the active item for highlighting. Used together with + row_generator to create the row_formatter. + """ + if not page.total: + declare(empty_message) + return + + if row_generator is not None: + row_formatter = partial(row_generator, active_id=active_id) + + items = format_page_items(page, row_formatter, output_format) + handle_output(items, page, columns, output_format, column_aliases) + + +def handle_output( + data: List[Dict[str, Any]], + page: Optional["Page[Any]"], + columns: str, + output_format: OutputFormat, + column_aliases: Optional[Dict[str, str]] = None, +) -> None: + """Handle output formatting for CLI commands. + + This function processes the output formatting parameters from CLI options + and calls the appropriate rendering function. + + Args: + data: List of dictionaries to render + page: Page object containing pagination info + output_format: Output format (table, json, yaml, tsv, csv). + columns: Comma-separated column names. If empty, all columns are shown. + column_aliases: Optional mapping of original column names to display + names. Use this to rename columns in the table output. + """ + cli_output = prepare_output( + data=data, + output_format=output_format, + columns=columns, + page=page, + column_aliases=column_aliases, + ) + if cli_output: + from zenml_cli import clean_output + + try: + clean_output(cli_output) + except (IOError, OSError) as err: + logger.warning("Failed to write clean output: %s", err) + print(cli_output) + + if page and output_format == "table": + print_page_info(page) + + +def prepare_output( + data: List[Dict[str, Any]], + output_format: OutputFormat = "table", + columns: Optional[str] = None, + page: Optional["Page[Any]"] = None, + column_aliases: Optional[Dict[str, str]] = None, +) -> str: + """Render data in specified format following ZenML CLI table guidelines. + + This function provides a centralized way to render tabular data across + all ZenML CLI commands with consistent formatting and multiple output + formats. + + Args: + data: List of dictionaries to render. + output_format: Output format (`table`, `json`, `yaml`, `tsv`, `csv`). + columns: Optional comma-separated list of column names to include. + Unrecognized column names will trigger a warning. + page: Optional page object for pagination metadata in JSON/YAML output. + column_aliases: Optional mapping of original column names to display + names. Use this to rename columns in the table output. + + Returns: + The rendered output in the specified format, or empty string if + no data is provided. + + Raises: + ValueError: If an unsupported output format is provided. + """ + if not data: + return "" + + available_keys = list(data[0].keys()) + + if columns and columns.strip().lower() == "all": + selected_columns = available_keys + elif columns: + requested_cols = [c.strip() for c in columns.split(",")] + col_mapping: Dict[str, str] = {} + unmatched_cols: List[str] = [] + + for req_col in requested_cols: + req_normalized = req_col.lower().replace("_", " ") + if req_col in available_keys: + col_mapping[req_col] = req_col + else: + matched = False + for key in available_keys: + key_normalized = key.lower().replace("_", " ") + if req_normalized == key_normalized: + col_mapping[req_col] = key + matched = True + break + if not matched: + unmatched_cols.append(req_col) + + if unmatched_cols: + normalized_keys = [ + key.lower().replace(" ", "_") for key in available_keys + ] + available_display = ", ".join(sorted(set(normalized_keys))) + warning( + f"Unknown column(s) ignored: {', '.join(unmatched_cols)}. " + f"Available: {available_display}" + ) + + selected_columns = list(col_mapping.values()) + else: + selected_columns = list(data[0].keys()) + + filtered_data = [ + {k: entry[k] for k in selected_columns if k in entry} for entry in data + ] + + pagination_dict = ( + { + "index": page.index, + "max_size": page.max_size, + "total_pages": page.total_pages, + "total": page.total, + } + if page + else None + ) + + if output_format == "json": + return _render_json(filtered_data, pagination=pagination_dict) + elif output_format == "yaml": + return _render_yaml(filtered_data, pagination=pagination_dict) + elif output_format == "tsv": + return _render_tsv(filtered_data) + elif output_format == "csv": + return _render_csv(filtered_data) + elif output_format == "table": + return _render_table(filtered_data, column_aliases=column_aliases) + else: + raise ValueError(f"Unsupported output format: {output_format}") + + +def _syntax_highlight(content: str, lexer: str) -> str: + """Apply syntax highlighting to content if colors are enabled. + + Syntax highlighting is only applied when output goes to an interactive + terminal (TTY). When output is redirected to a file or piped to another + program, plain text is returned to ensure machine-readable output. + + Args: + content: The text content to highlight + lexer: The lexer to use (e.g., "json", "yaml") + + Returns: + Syntax-highlighted string if colors enabled and output is a TTY, + otherwise the original content unchanged. + """ + if os.getenv("NO_COLOR"): + return content + + # Import here to avoid circular imports at module load time + from zenml_cli import is_terminal_output + + if not is_terminal_output(): + return content + + syntax = Syntax( + content, lexer, theme="ansi_dark", background_color="default" + ) + output_buffer = io.StringIO() + temp_console = Console(file=output_buffer, force_terminal=True) + temp_console.print(syntax) + return output_buffer.getvalue().rstrip() + + +def _render_json( + data: List[Dict[str, Any]], + pagination: Optional[Dict[str, Any]] = None, +) -> str: + """Render data as JSON. + + Args: + data: List of data dictionaries to render + pagination: Optional pagination metadata + + Returns: + JSON string representation of the data + """ + output: Dict[str, Any] = {"items": data} + + if pagination: + output["pagination"] = pagination + + json_str = json.dumps(output, indent=2, default=str) + return _syntax_highlight(json_str, "json") + + +def _render_yaml( + data: List[Dict[str, Any]], + pagination: Optional[Dict[str, Any]] = None, +) -> str: + """Render data as YAML. + + Args: + data: List of data dictionaries to render + pagination: Optional pagination metadata + + Returns: + YAML string representation of the data + """ + output: Dict[str, Any] = {"items": data} + + if pagination: + output["pagination"] = pagination + + yaml_str = yaml.dump(output, default_flow_style=False) + return _syntax_highlight(yaml_str, "yaml") + + +def _render_delimited( + data: List[Dict[str, Any]], + delimiter: str = ",", +) -> str: + """Render data as delimited values (CSV/TSV). + + Args: + data: List of data dictionaries to render + delimiter: Field delimiter character + + Returns: + Delimited string representation of the data + """ + if not data: + return "" + + output_buffer = io.StringIO() + headers = list(data[0].keys()) + writer = csv.DictWriter( + output_buffer, + fieldnames=headers, + delimiter=delimiter, + lineterminator="\n", + ) + writer.writeheader() + writer.writerows(data) + return output_buffer.getvalue().strip() + + +def _render_tsv(data: List[Dict[str, Any]]) -> str: + """Render data as TSV (Tab-Separated Values). + + Args: + data: List of data dictionaries to render + + Returns: + TSV string representation of the data + """ + return _render_delimited(data, delimiter="\t") + + +def _render_csv(data: List[Dict[str, Any]]) -> str: + """Render data as CSV (Comma-Separated Values). + + Args: + data: List of data dictionaries to render + + Returns: + CSV string representation of the data + """ + return _render_delimited(data, delimiter=",") + + +def _get_terminal_width() -> Optional[int]: + """Get terminal width from ZENML_CLI_COLUMN_WIDTH environment variable or shutil. + + Checks the ZENML_CLI_COLUMN_WIDTH environment variable first, then falls back + to shutil.get_terminal_size() for automatic detection. + + Returns: + Terminal width in characters, or None if cannot be determined + """ + columns_env = handle_int_env_var(ENV_ZENML_CLI_COLUMN_WIDTH, default=0) + if columns_env > 0: + return columns_env + + try: + size = shutil.get_terminal_size() + return size.columns + except (AttributeError, OSError): + return None + + +def _render_table( + data: List[Dict[str, Any]], + title: Optional[str] = None, + caption: Optional[str] = None, + column_aliases: Optional[Dict[str, str]] = None, +) -> str: + """Render data as a formatted table following ZenML guidelines. + + Args: + data: List of data dictionaries to render + title: Optional title for the table. + caption: Optional caption for the table. + column_aliases: Optional mapping of original column names to display + names. Use this to rename columns in the table output. + + Returns: + Formatted table string representation of the data + """ + aliases = column_aliases or {} + headers = list(data[0].keys()) + longest_values: Dict[str, int] = {} + for header in headers: + display_name = aliases.get(header, header) + header_display = display_name.replace("_", " ").upper() + longest_values[header] = max( + len(header_display), + max(len(str(row.get(header, ""))) for row in data), + ) + + terminal_width = _get_terminal_width() + console_width = ( + max(80, min(terminal_width, 200)) if terminal_width else 150 + ) + estimated_width = sum(longest_values.values()) + (len(headers) * 3) + + if estimated_width > console_width: + declare( + "Large tables may wrap, truncate, or hide columns depending on terminal " + "width.\n" + "- Use --output =json|yaml|csv|tsv for full data\n" + "- Or optionally limit visible columns with --columns\n" + ) + + rich_table = Table( + box=box.SIMPLE_HEAD, + show_header=True, + show_lines=False, + pad_edge=False, + collapse_padding=False, + expand=False, + show_edge=False, + header_style="bold", + title=title, + caption=caption, + ) + + for header in headers: + lower = header.lower().strip() + is_active_col = lower == "active" + is_id_col = _is_id_column(header) + is_name_col = _is_name_column(header) + display_name = aliases.get(header, header) + header_display = ( + "active" + if is_active_col + else display_name.replace("_", " ").upper() + ) + justify: Literal["default", "left", "center", "right", "full"] = "left" + overflow: Literal["fold", "crop", "ellipsis", "ignore"] = "ellipsis" + min_width: Optional[int] = None + no_wrap = False + + if is_active_col: + justify = "center" + overflow = "crop" + no_wrap = True + elif is_id_col or is_name_col: + overflow = "fold" + min_width = longest_values[header] + + rich_table.add_column( + header=header_display, + justify=justify, + overflow=overflow, + no_wrap=no_wrap, + min_width=min_width, + ) + + if data: + rich_table.add_section() + + for row in data: + values = [] + for header in headers: + value = str(row.get(header, "")) + + if not os.getenv("NO_COLOR"): + value = _colorize_value(header, value) + if value.startswith("http") and " " not in value: + value = f"[link={value}]{value}[/link]" + + values.append(value) + + rich_table.add_row(*values) + + output_buffer = io.StringIO() + table_console = Console( + width=console_width, + force_terminal=not os.getenv("NO_COLOR"), + no_color=os.getenv("NO_COLOR") is not None, + file=output_buffer, + ) + padded_table = Padding(rich_table, (0, 0, 1, 2)) + table_console.print(padded_table) + + return output_buffer.getvalue() + + +def _colorize_value(column: str, value: str) -> str: + """Apply colorization to values based on column type and content. + + Args: + column: Column name to determine colorization rules + value: Value to potentially colorize + + Returns: + Potentially colorized value with Rich markup + """ + if any( + keyword in column.lower() for keyword in ["status", "state", "health"] + ): + value_lower = value.lower() + green_statuses = { + "active", + "healthy", + "succeeded", + "completed", + "verified", + } + yellow_statuses = { + "running", + "pending", + "initializing", + "starting", + "warning", + "creating", + "updating", + } + red_statuses = { + "failed", + "error", + "unhealthy", + "stopped", + "crashed", + "deleted", + } + + if value_lower in green_statuses: + return f"[green]{value}[/green]" + elif value_lower in yellow_statuses: + return f"[yellow]{value}[/yellow]" + elif value_lower in red_statuses: + return f"[red]{value}[/red]" + + return value + + stripped = value.strip() + if stripped in ("-", ""): + display = stripped or "-" + return f"[dim]{display}[/dim]" + + return value + + +def _is_id_column(name: str) -> bool: + """Check if column name should be treated as an ID column. + + Args: + name: Column name to check + + Returns: + True if this is an ID column (word boundary matching) + """ + lower = name.lower().strip() + return lower == "id" or lower.endswith(" id") or lower.endswith("_id") + + +def _is_name_column(name: str) -> bool: + """Check if column name should be treated as a NAME column. + + Args: + name: Column name to check + + Returns: + True if this is a NAME column (word boundary matching) + """ + lower = name.lower().strip() + return ( + lower == "name" or lower.endswith(" name") or lower.endswith("_name") + ) + + +def _get_user_name(user: Union[UserResponse, UUID, None]) -> str: + """Get the name of a user. + + Args: + user: User object (UserResponse, UUID, or None) + + Returns: + Human-readable user name or '-' if unavailable + """ + if not user: + return "-" + + if isinstance(user, UserResponse): + return str(user.name) + + return str(user) + + +def _extract_model_columns( + model: BaseModel, + columns: Optional[Sequence[str]], + exclude_columns: Optional[Sequence[str]], +) -> List[str]: + """Extract column names from a model following BaseIdentifiedResponse semantics. + + Args: + model: The model to extract columns from + columns: Optional explicit list of columns to include + exclude_columns: Optional list of columns to exclude + + Returns: + List of column names to display + """ + exclude_list = list(exclude_columns or []) + + if columns: + return list(columns) + + if isinstance(model, BaseIdentifiedResponse): + include_columns: List[str] = ["id"] + + if "name" in type(model).model_fields: + include_columns.append("name") + + include_columns.extend( + [ + k + for k in type(model.get_body()).model_fields.keys() + if k not in exclude_list + ] + ) + + if model.metadata is not None: + include_columns.extend( + [ + k + for k in type(model.get_metadata()).model_fields.keys() + if k not in exclude_list + ] + ) + return include_columns + + return [k for k in model.model_dump().keys() if k not in exclude_list] + + +def _format_response_value(value: Any) -> Any: + """Format a value for display in pydantic model tables. + + Args: + value: The value to format + + Returns: + Formatted value (string or list of strings) + """ + if isinstance(value, BaseIdentifiedResponse): + if "name" in type(value).model_fields: + return str(getattr(value, "name")) + return str(value.id) + + if isinstance(value, list): + formatted_items: List[str] = [] + for v in value: + if isinstance(v, BaseIdentifiedResponse): + formatted_items.append(str(_format_response_value(v))) + else: + formatted_items.append(str(v)) + return formatted_items + + if isinstance(value, Set): + return [str(v) for v in value] + + return str(value) + + +def _print_key_value_table( + items: Dict[str, Any], + title: Optional[str] = None, + key_header: str = "PROPERTY", + capitalize_keys: bool = True, +) -> None: + """Print a simple key/value table with consistent formatting. + + Args: + items: Dictionary of key-value pairs to display + title: Optional table title + key_header: Header text for the key column + capitalize_keys: Whether to uppercase the keys + """ + rich_table = table.Table( + box=None, + title=title, + show_lines=False, + ) + rich_table.add_column(key_header, overflow="fold") + rich_table.add_column("VALUE", overflow="fold") + + for k, v in items.items(): + key_display = k.upper() if capitalize_keys else str(k) + rich_table.add_row(key_display, str(v)) + + console.print(rich_table) diff --git a/src/zenml/client.py b/src/zenml/client.py index d4cdc169683..189bf401182 100644 --- a/src/zenml/client.py +++ b/src/zenml/client.py @@ -8225,6 +8225,7 @@ def list_service_accounts( size: int = PAGE_SIZE_DEFAULT, logical_operator: LogicalOperators = LogicalOperators.AND, id: Optional[Union[UUID, str]] = None, + external_user_id: Optional[Union[UUID, str]] = None, created: Optional[Union[datetime, str]] = None, updated: Optional[Union[datetime, str]] = None, name: Optional[str] = None, @@ -8240,6 +8241,7 @@ def list_service_accounts( size: The maximum size of all pages logical_operator: Which logical operator to use [and, or] id: Use the id of stacks to filter by. + external_user_id: Use the external user id for filtering. created: Use to filter by time of creation updated: Use the last updated date for filtering name: Use the service account name for filtering @@ -8258,6 +8260,7 @@ def list_service_accounts( size=size, logical_operator=logical_operator, id=id, + external_user_id=external_user_id, created=created, updated=updated, name=name, diff --git a/src/zenml/constants.py b/src/zenml/constants.py index 583c5f5802f..893506c255c 100644 --- a/src/zenml/constants.py +++ b/src/zenml/constants.py @@ -217,6 +217,9 @@ def handle_int_env_var(var: str, default: int = 0) -> int: "ZENML_WORKLOAD_TOKEN_EXPIRATION_LEEWAY" ) +ENV_ZENML_DEFAULT_OUTPUT = "ZENML_DEFAULT_OUTPUT" +ENV_ZENML_CLI_COLUMN_WIDTH = "ZENML_CLI_COLUMN_WIDTH" + # Logging variables IS_DEBUG_ENV: bool = handle_bool_env_var(ENV_ZENML_DEBUG, default=False) diff --git a/src/zenml/integrations/gitlab/code_repositories/gitlab_code_repository.py b/src/zenml/integrations/gitlab/code_repositories/gitlab_code_repository.py index f40256f13b9..2a92fc4d1ea 100644 --- a/src/zenml/integrations/gitlab/code_repositories/gitlab_code_repository.py +++ b/src/zenml/integrations/gitlab/code_repositories/gitlab_code_repository.py @@ -196,7 +196,7 @@ def check_remote_url(self, url: str) -> bool: f"@{host}:" r"(?P\d+)?" r"(?(scheme_with_delimiter)/|/?)" - f"{group}/{project}(\.git)?$", + rf"{group}/{project}(\.git)?$", ) if ssh_regex.fullmatch(url): return True diff --git a/src/zenml_cli/__init__.py b/src/zenml_cli/__init__.py new file mode 100644 index 00000000000..86007e13224 --- /dev/null +++ b/src/zenml_cli/__init__.py @@ -0,0 +1,91 @@ +# Copyright (c) ZenML GmbH 2025. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express +# or implied. See the License for the specific language governing +# permissions and limitations under the License. +"""Core CLI functionality.""" + +import logging +import sys +from typing import List + +# Global variable to store original stdout for CLI clean output +_original_stdout = sys.stdout + + +def reroute_stdout() -> None: + """Reroute logging to stderr for CLI commands. + + This function redirects sys.stdout to sys.stderr so that all logging + output goes to stderr, while preserving the original stdout for clean + output that can be piped. + """ + modified_handlers: List[logging.StreamHandler] = [] + + # Reroute stdout to stderr + sys.stdout = sys.stderr + + # Handle existing root logger handlers that hold references to original stdout + for handler in logging.root.handlers: + if ( + isinstance(handler, logging.StreamHandler) + and handler.stream is _original_stdout + ): + handler.stream = sys.stderr + modified_handlers.append(handler) + + # Handle ALL existing individual logger handlers that hold references to original stdout + for _, logger in logging.Logger.manager.loggerDict.items(): + if isinstance(logger, logging.Logger): + for handler in logger.handlers: + if ( + isinstance(handler, logging.StreamHandler) + and handler.stream is _original_stdout + ): + handler.setStream(sys.stderr) + modified_handlers.append(handler) + + +def clean_output(text: str) -> None: + """Output text to stdout for clean piping, bypassing stderr rerouting. + + This function ensures that specific output goes to the original stdout + even when the CLI has rerouted stdout to stderr. This is useful for + outputting data that should be pipeable (like JSON, CSV, YAML) while + keeping logs and status messages in stderr. + + Args: + text: Text to output to stdout. + """ + _original_stdout.write(text) + if not text.endswith("\n"): + _original_stdout.write("\n") + _original_stdout.flush() + + +def is_terminal_output() -> bool: + """Check if the CLI output is going to an interactive terminal. + + This checks the original stdout (before CLI rerouting to stderr) to + determine if output will be displayed interactively or redirected + to a file/pipe. + + Returns: + True if output goes to an interactive terminal (TTY), False if + redirected to a file or piped to another program. + """ + return _original_stdout.isatty() + + +reroute_stdout() + +# Import the cli only after rerouting stdout +from zenml.cli.cli import cli # noqa: E402, F401 diff --git a/tests/integration/functional/cli/test_artifact.py b/tests/integration/functional/cli/test_artifact.py index 2a189c81d9d..98744085e72 100644 --- a/tests/integration/functional/cli/test_artifact.py +++ b/tests/integration/functional/cli/test_artifact.py @@ -21,7 +21,7 @@ def test_artifact_list(clean_client_with_run): """Test that `zenml artifact list` does not fail.""" runner = CliRunner() - list_command = cli.commands["pipeline"].commands["list"] + list_command = cli.commands["artifact"].commands["list"] result = runner.invoke(list_command) assert result.exit_code == 0 diff --git a/tests/integration/functional/cli/test_secret.py b/tests/integration/functional/cli/test_secret.py index 1d1a86f6a89..4acedc66584 100644 --- a/tests/integration/functional/cli/test_secret.py +++ b/tests/integration/functional/cli/test_secret.py @@ -83,25 +83,43 @@ def test_create_secret_with_values(): def test_list_secret_works(): """Test that the secret list command works.""" - runner = CliRunner() - with cleanup_secrets() as secret_name: - result1 = runner.invoke( - secret_list_command, - ) - assert result1.exit_code == 0 - assert secret_name not in result1.output + import io - runner = CliRunner() - runner.invoke( - secret_create_command, - [secret_name, "--test_value=aria", "--test_value2=axl"], - ) + import zenml_cli - result2 = runner.invoke( - secret_list_command, - ) - assert result2.exit_code == 0 - assert secret_name in result2.output + # Save original _original_stdout for cleanup + original_stdout = zenml_cli._original_stdout + + runner = CliRunner(mix_stderr=False) + try: + with cleanup_secrets() as secret_name: + # Capture clean_output writes by replacing _original_stdout + # with a StringIO buffer before each CLI invocation + buffer1 = io.StringIO() + zenml_cli._original_stdout = buffer1 + + result1 = runner.invoke(secret_list_command) + output1 = buffer1.getvalue() + result1.output + + assert result1.exit_code == 0 + assert secret_name not in output1 + + runner.invoke( + secret_create_command, + [secret_name, "--test_value=aria", "--test_value2=axl"], + ) + + buffer2 = io.StringIO() + zenml_cli._original_stdout = buffer2 + + result2 = runner.invoke(secret_list_command) + output2 = buffer2.getvalue() + result2.output + + assert result2.exit_code == 0 + assert secret_name in output2 + finally: + # Restore original state + zenml_cli._original_stdout = original_stdout def test_get_secret_works(): @@ -251,7 +269,6 @@ def test_delete_secret_works(): def test_rename_secret_works(): """Test that the secret rename command works.""" - runner = CliRunner() with cleanup_secrets() as secret_name: diff --git a/tests/integration/functional/cli/test_utils.py b/tests/integration/functional/cli/test_utils.py index 4d17640f302..750dc387291 100644 --- a/tests/integration/functional/cli/test_utils.py +++ b/tests/integration/functional/cli/test_utils.py @@ -89,17 +89,6 @@ def test_converting_structured_str_to_dict(): ) -def test_parsing_unknown_component_attributes(): - """Test that our ability to parse CLI arguments works.""" - assert cli_utils.parse_unknown_component_attributes( - ["--foo", "--bar", "--baz", "--qux"] - ) == ["foo", "bar", "baz", "qux"] - with pytest.raises(AssertionError): - cli_utils.parse_unknown_component_attributes(["foo"]) - with pytest.raises(AssertionError): - cli_utils.parse_unknown_component_attributes(["foo=bar=qux"]) - - def test_validate_keys(): """Test that validation of proper identifier as key works""" with pytest.raises(ClickException):