# Integrations
> This bundle contains all pages in the Integrations section.
> Source: https://www.union.ai/docs/v2/union/integrations/

=== PAGE: https://www.union.ai/docs/v2/union/integrations ===

# Integrations

> **📝 Note**
>
> An LLM-optimized bundle of this entire section is available at [`section.md`](section.md).
> This single file contains all pages in this section, optimized for AI coding agent context.

Flyte 2 is designed to be extensible by default. While the core platform covers the most common orchestration needs, many production workloads require specialized infrastructure, external services or execution semantics that go beyond the core runtime.

Flyte 2 exposes these capabilities through integrations.

Under the hood, integrations are implemented using Flyte 2's plugin system, which provides a consistent way to extend the platform without modifying core execution logic.

An integration allows you to declaratively enable new capabilities such as distributed compute frameworks or third-party services without manually managing infrastructure. You specify what you need, and Flyte takes care of how it is provisioned, used and cleaned up.

This page covers:

- The types of integrations Flyte 2 supports today
- How integrations fit into Flyte 2's execution model
- How to use integrations in your tasks
- The integrations available out of the box

If you need functionality that doesn't exist yet, Flyte 2's plugin system is intentionally open-ended. You can build and register your own integrations using the same architecture described here.

## Integration categories

Flyte 2 integrations fall into the following categories:

1. **Distributed compute**: Provision transient compute clusters to run tasks across multiple nodes, with automatic lifecycle management.
2. **Agentic AI**: Use Flyte tasks as tools for LLM agents (OpenAI, Anthropic, Gemini) and generate tested code from natural-language prompts.
3. **Configuration**: Compose and pass hierarchical configuration objects between tasks, with type-safe schemas and CLI/YAML composition.
4. **Experiment tracking**: Integrate with experiment tracking platforms for logging metrics, parameters, and artifacts.
5. **Data validation**: Enforce schema contracts on dataframes flowing between tasks, with automatic validation reports.
6. **Connectors**: Stateless, long-running services that receive execution requests via gRPC and then submit work to external (or internal) systems.
7. **LLM Serving**: Deploy and serve large language models with an OpenAI-compatible API.

## Distributed compute

Distributed compute integrations allow tasks to run on dynamically provisioned clusters. These clusters are created just-in-time, scoped to the task execution and torn down automatically when the task completes.

This enables large-scale parallelism without requiring users to operate or maintain long-running infrastructure.

### Supported distributed compute integrations

| Plugin               | Description                                      | Common use cases                                       |
| -------------------- | ------------------------------------------------ | ------------------------------------------------------ |
| [Ray](./ray/_index)         | Provisions Ray clusters via KubeRay              | Distributed Python, ML training, hyperparameter tuning |
| [Spark](./spark/_index)     | Provisions Spark clusters via Spark Operator     | Large-scale data processing, ETL pipelines             |
| [Dask](./dask/_index)       | Provisions Dask clusters via Dask Operator       | Parallel Python workloads, dataframe operations        |
| [PyTorch](./pytorch/_index) | Distributed PyTorch training with elastic launch | Single-node and multi-node training                    |

Each plugin encapsulates:

- Cluster provisioning
- Resource configuration
- Networking and service discovery
- Lifecycle management and teardown

From the task author's perspective, these details are abstracted away.

### How the plugin system works

At a high level, Flyte 2's distributed compute plugin architecture follows a simple and consistent pattern.

#### 1. Registration

Each plugin registers itself with Flyte 2's core plugin registry:

- **`TaskPluginRegistry`**: The central registry for all distributed compute plugins
- Each plugin declares:
  - Its configuration schema
  - How that configuration maps to execution behavior

This registration step makes the plugin discoverable by the runtime.
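
The following minimal sketch makes the registration idea concrete. It is a registry that maps plugin configuration types to handlers. This is a hypothetical illustration in plain Python: `PluginRegistry`, `resolve`, and `HypotheticalDaskConfig` are invented names, not Flyte's `TaskPluginRegistry` API.

```python
from dataclasses import dataclass
from typing import Callable, Dict


class PluginRegistry:
    """Hypothetical registry mapping config types to handlers (not Flyte's API)."""

    def __init__(self) -> None:
        self._handlers: Dict[type, Callable[[object], dict]] = {}

    def register(self, config_type: type, handler: Callable[[object], dict]) -> None:
        # Each config type may be claimed by exactly one plugin
        if config_type in self._handlers:
            raise ValueError(f"{config_type.__name__} is already registered")
        self._handlers[config_type] = handler

    def resolve(self, config: object) -> dict:
        # Dispatch on the exact type of the configuration object
        handler = self._handlers.get(type(config))
        if handler is None:
            raise KeyError(f"No plugin registered for {type(config).__name__}")
        return handler(config)


@dataclass
class HypotheticalDaskConfig:
    number_of_workers: int


registry = PluginRegistry()
registry.register(
    HypotheticalDaskConfig,
    lambda cfg: {"task_type": "dask", "workers": cfg.number_of_workers},
)

print(registry.resolve(HypotheticalDaskConfig(number_of_workers=4)))
# {'task_type': 'dask', 'workers': 4}
```

Because the lookup is keyed on the configuration type, the runtime needs no knowledge of individual plugins. Attaching a different configuration object selects a different handler.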

#### 2. Task environments and plugin configuration

Integrations are activated through a `TaskEnvironment`.

A `TaskEnvironment` bundles:

- A container image
- Execution settings
- A plugin configuration object, passed via the `plugin_config` parameter

The plugin configuration describes _what_ infrastructure or integration the task requires.

#### 3. Automatic provisioning and execution

When a task associated with a `TaskEnvironment` runs:

1. Flyte inspects the environment's plugin configuration
2. The plugin provisions the required infrastructure or integration
3. The task executes with access to that capability
4. Flyte cleans up all transient resources after completion
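
The provision, execute, and clean-up sequence behaves like a context manager whose teardown always runs. The sketch below models it in plain Python. `transient_cluster`, `run_task`, and the scheduler address are hypothetical stand-ins, not part of Flyte.

```python
import contextlib
from typing import Iterator, List

events: List[str] = []


@contextlib.contextmanager
def transient_cluster(workers: int) -> Iterator[str]:
    # Hypothetical stand-in for a plugin's provision/teardown around a task run
    events.append(f"provision {workers} workers")
    try:
        yield "tcp://scheduler:8786"  # placeholder cluster address
    finally:
        # Runs whether the task body returns normally or raises
        events.append("teardown")


def run_task(workers: int, fail: bool = False) -> str:
    with transient_cluster(workers) as address:
        events.append(f"execute against {address}")
        if fail:
            raise RuntimeError("task failed")
        return "ok"


print(run_task(4))  # ok
print(events)
```

Placing teardown in `finally` is what guarantees that a failing task still releases its transient resources.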

### Example: Using the Dask plugin

Below is a complete example showing how a task gains access to a Dask cluster simply by running inside an environment configured with the Dask plugin.

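```python
import flyte
from flyteplugins.dask import Dask, WorkerGroup

# Container image with the Dask plugin installed
# (package name assumed to follow the flyteplugins-<name> convention)
image = flyte.Image.from_debian_base(name="dask").with_pip_packages("flyteplugins-dask")

# Define the Dask cluster configuration
dask_config = Dask(
    workers=WorkerGroup(number_of_workers=4)
)

# Create a task environment that enables Dask
env = flyte.TaskEnvironment(
    name="dask_env",
    plugin_config=dask_config,
    image=image,
)

def transform(x: int) -> int:
    # Example per-item work that runs on the Dask workers
    return x * 2

# Any task in this environment has access to the Dask cluster
@env.task
async def process_data(data: list[int]) -> list[int]:
    from distributed import Client

    client = Client()  # Automatically connects to the provisioned cluster
    futures = client.map(transform, data)
    return client.gather(futures)
```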

When `process_data` executes, Flyte performs the following steps:

1. Provisions a Dask cluster with 4 workers
2. Executes the task with network access to the cluster
3. Tears down the cluster once the task completes

No cluster management logic appears in the task code. The task only expresses intent.

### Key design principle

All distributed compute integrations follow the same mental model:

- You declare the required capability via configuration
- You attach that configuration to a task environment
- Tasks decorated with that environment automatically gain access to the capability

This makes it easy to swap execution backends or introduce distributed compute incrementally without rewriting workflows.

## Agentic AI

Agentic AI integrations provide drop-in replacements for LLM provider SDKs. They let you use Flyte tasks as agent tools so that tool calls run with full Flyte observability, retries, and caching.

### Supported agentic AI integrations

| Plugin                              | Description                                                  | Common use cases                     |
| ----------------------------------- | ------------------------------------------------------------ | ------------------------------------ |
| [OpenAI](./openai/_index)           | Drop-in replacement for OpenAI Agents SDK `function_tool`    | Agentic workflows with OpenAI models |
| [Anthropic](./anthropic/_index)     | Agent loop and `function_tool` for the Anthropic Claude SDK  | Agentic workflows with Claude        |
| [Gemini](./gemini/_index)           | Agent loop and `function_tool` for the Google Gemini SDK     | Agentic workflows with Gemini        |
| [Code generation](./codegen/_index) | LLM-driven code generation with automatic testing in sandboxes | Data processing, ETL, analysis pipelines |

## Experiment tracking

Experiment tracking integrations let you log metrics, parameters, and artifacts to external tracking platforms during Flyte task execution.

### Supported experiment tracking integrations

| Plugin                               | Description                  | Common use cases                              |
| ------------------------------------ | ---------------------------- | --------------------------------------------- |
| [MLflow](./mlflow/_index)            | MLflow experiment tracking   | Experiment tracking, autologging, model registry |
| [Weights and Biases](./wandb/_index) | Weights & Biases integration | Experiment tracking and hyperparameter tuning |

## Configuration

Configuration integrations let you compose and pass hierarchical configuration objects between Flyte tasks, with type-safe schemas and CLI/YAML composition.

### Supported configuration integrations

| Plugin                              | Description                                                | Common use cases                                            |
| ----------------------------------- | ---------------------------------------------------------- | ----------------------------------------------------------- |
| [OmegaConf](./omegaconf/_index)     | `DictConfig` / `ListConfig` as native task input and output types | Passing composed configs between tasks, structured configs, YAML-driven pipelines |
| [Hydra](./hydra/_index)             | Hydra config composition and sweep submission for Flyte tasks    | YAML-driven experiment composition, grid and Bayesian sweeps, hardware presets |

## Data validation

Data validation integrations enforce schema contracts on the dataframes flowing between tasks. They validate data at task boundaries, catch type and constraint violations early, and produce HTML reports visible in the Flyte UI.

### Supported data validation integrations

| Plugin                         | Description                                         | Common use cases                                          |
| ------------------------------ | --------------------------------------------------- | --------------------------------------------------------- |
| [Pandera](./pandera/_index)    | Validates dataframes with pandera `DataFrameModel` schemas | Schema enforcement, data quality checks, validation reports |

## Connectors

Connectors are stateless, long-running services that receive execution requests via gRPC and then submit work to external (or internal) systems. Each connector runs as its own Kubernetes deployment and is triggered when a Flyte task of the matching type is executed.

Connectors normally run inside the control plane. Because they are ordinary Python services that can be spawned in-process, you can also run them locally, as long as the required secrets and credentials are available on your machine.

Connectors are designed to scale horizontally and reduce load on the core Flyte backend because they execute _outside_ the core system. This decoupling makes connectors efficient, resilient, and easy to iterate on. You can even test them locally without modifying backend configuration, which reduces friction during development.

### Supported connectors

| Connector                          | Description                                    | Common use cases                         |
| ---------------------------------- | ---------------------------------------------- | ---------------------------------------- |
| [Snowflake](./snowflake/_index)    | Run SQL queries on Snowflake asynchronously    | Data warehousing, ETL, analytics queries |
| [BigQuery](./bigquery/_index)      | Run SQL queries on Google BigQuery             | Data warehousing, ETL, analytics queries |
| [Databricks](./databricks/_index)  | Run PySpark jobs on Databricks clusters        | Large-scale data processing, Spark ETL   |

### Creating a new connector

If none of the existing connectors meet your needs, you can build your own.

> [!NOTE]
> Connectors communicate via Protobuf, so in theory they can be implemented in any language.
> Today, only **Python** connectors are supported.

### Async connector interface

To implement a new async connector, extend `AsyncConnector` and implement the following methods, all of which must be idempotent:

| Method   | Purpose                                                     |
| -------- | ----------------------------------------------------------- |
| `create` | Launch the external job (via REST, gRPC, SDK, or other API) |
| `get`    | Fetch current job state (return job status or output)       |
| `delete` | Delete / cancel the external job                            |

To test the connector locally, the connector task should inherit from
[AsyncConnectorExecutorMixin](https://github.com/flyteorg/flyte-sdk/blob/1d49299294cd5e15385fe8c48089b3454b7a4cd1/src/flyte/connectors/_connector.py#L206). This mixin simulates how the Flyte 2 system executes asynchronous connector tasks, making it easier to validate your connector implementation before deploying it.
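
The mixin's job is to drive the three methods the way the platform would. The sketch below shows that control flow with an in-memory fake connector: `create` once, `get` until a terminal phase, and `delete` if polling gives up or the run is cancelled. `FakeConnector`, `drive`, and the string phases are hypothetical. Real connectors return `ResourceMeta` and `Resource` objects, and the actual executor logic may differ.

```python
import asyncio
from dataclasses import dataclass, field
from typing import Dict


@dataclass
class FakeJob:
    polls_until_done: int
    polls: int = 0
    cancelled: bool = False


@dataclass
class FakeConnector:
    """Hypothetical in-memory connector with create/get/delete semantics."""

    jobs: Dict[str, FakeJob] = field(default_factory=dict)

    async def create(self, job_id: str) -> str:
        # Idempotent: creating the same job twice keeps the original
        self.jobs.setdefault(job_id, FakeJob(polls_until_done=3))
        return job_id

    async def get(self, job_id: str) -> str:
        job = self.jobs[job_id]
        if job.cancelled:
            return "ABORTED"
        job.polls += 1  # simulates the external job making progress
        return "SUCCEEDED" if job.polls >= job.polls_until_done else "RUNNING"

    async def delete(self, job_id: str) -> None:
        # Idempotent: safe to call for unknown, finished, or already-cancelled jobs
        job = self.jobs.get(job_id)
        if job is not None:
            job.cancelled = True


async def drive(connector: FakeConnector, job_id: str, poll_interval: float = 0.01, max_polls: int = 100) -> str:
    handle = await connector.create(job_id)
    try:
        for _ in range(max_polls):
            phase = await connector.get(handle)
            if phase != "RUNNING":
                return phase
            await asyncio.sleep(poll_interval)
    except asyncio.CancelledError:
        # The run was aborted: cancel the external job, then propagate
        await connector.delete(handle)
        raise
    # Polling budget exhausted: cancel the external job
    await connector.delete(handle)
    return "TIMED_OUT"


print(asyncio.run(drive(FakeConnector(), "job-1")))  # SUCCEEDED
```

Idempotency matters here because the platform may retry any of these calls. A second `create` or a late `delete` must not start a duplicate job or fail.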

### Example: Model training connector

The following example implements a connector that launches a model training job on an external training service.

```python
import typing
from dataclasses import dataclass

import httpx
from flyte.connectors import AsyncConnector, Resource, ResourceMeta
from flyteidl2.core.execution_pb2 import TaskExecution, TaskLog
from flyteidl2.core.tasks_pb2 import TaskTemplate
from google.protobuf import json_format

@dataclass
class ModelTrainJobMeta(ResourceMeta):
    job_id: str
    endpoint: str

class ModelTrainingConnector(AsyncConnector):
    """
    Example connector that launches an ML model training job on an external training service.

    POST → launch training job
    GET  → poll training progress
    DELETE → cancel training job
    """

    name = "Model Training Connector"
    task_type_name = "external_model_training"
    metadata_type = ModelTrainJobMeta

    async def create(
        self,
        task_template: TaskTemplate,
        inputs: typing.Optional[typing.Dict[str, typing.Any]],
        **kwargs,
    ) -> ModelTrainJobMeta:
        """
        Submit training job via POST.
        Response returns job_id we later use in get().
        """
        custom = json_format.MessageToDict(task_template.custom) if task_template.custom else {}
        if "endpoint" not in custom:
            raise ValueError("Task template is missing the 'endpoint' custom config")
        inputs = inputs or {}
        async with httpx.AsyncClient() as client:
            r = await client.post(
                custom["endpoint"],
                json={"dataset_uri": inputs["dataset_uri"], "epochs": inputs["epochs"]},
            )
        r.raise_for_status()
        return ModelTrainJobMeta(job_id=r.json()["job_id"], endpoint=custom["endpoint"])

    async def get(self, resource_meta: ModelTrainJobMeta, **kwargs) -> Resource:
        """
        Poll the external API for the training job's status.
        Must be safe to call repeatedly.
        """
        async with httpx.AsyncClient() as client:
            r = await client.get(f"{resource_meta.endpoint}/{resource_meta.job_id}")
        r.raise_for_status()

        data = r.json()
        log_links = [
            TaskLog(name="training-dashboard", uri=f"https://example-mltrain.com/train/{resource_meta.job_id}")
        ]

        if data["status"] == "finished":
            return Resource(
                phase=TaskExecution.SUCCEEDED,
                log_links=log_links,
                outputs={"results": data["results"]},
            )

        # Assumes the example service reports errors as status "failed"
        if data["status"] == "failed":
            return Resource(
                phase=TaskExecution.FAILED,
                message=data.get("error", "Training job failed"),
                log_links=log_links,
            )

        return Resource(phase=TaskExecution.RUNNING, log_links=log_links)

    async def delete(self, resource_meta: ModelTrainJobMeta, **kwargs):
        """
        Optionally call DELETE on external API.
        Safe even if job already completed.
        """
        async with httpx.AsyncClient() as client:
            await client.delete(f"{resource_meta.endpoint}/{resource_meta.job_id}")
```

To use this connector, define a task whose `task_type` matches the connector's `task_type_name` (here, `external_model_training`).

```python
import flyte.io
from typing import Any, Dict, Optional

from flyte.extend import TaskTemplate
from flyte.connectors import AsyncConnectorExecutorMixin
from flyte.models import NativeInterface, SerializationContext

class ModelTrainTask(AsyncConnectorExecutorMixin, TaskTemplate):
    _TASK_TYPE = "external_model_training"

    def __init__(
        self,
        name: str,
        endpoint: str,
        **kwargs,
    ):
        super().__init__(
            name=name,
            interface=NativeInterface(
                inputs={"epochs": int, "dataset_uri": str},
                outputs={"results": flyte.io.File},
            ),
            task_type=self._TASK_TYPE,
            **kwargs,
        )
        self.endpoint = endpoint

    def custom_config(self, sctx: SerializationContext) -> Optional[Dict[str, Any]]:
        return {"endpoint": self.endpoint}
```

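Here is an example of how to use `ModelTrainTask`. It assumes the connector and task classes above are published as `flyteplugins-model-training`, a placeholder name for your own connector package: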

```python
import flyte
from flyteplugins.model_training import ModelTrainTask

model_train_task = ModelTrainTask(
    name="model_train",
    endpoint="https://example-mltrain.com",
)

model_train_env = flyte.TaskEnvironment.from_task("model_train_env", model_train_task)

env = flyte.TaskEnvironment(
    name="hello_world",
    resources=flyte.Resources(memory="250Mi"),
    image=flyte.Image.from_debian_base(name="model_training").with_pip_packages(
        "flyteplugins-model-training", pre=True
    ),
    depends_on=[model_train_env],
)

@env.task
def data_prep() -> str:
    return "gs://my-bucket/dataset.csv"

@env.task
def train_model(epochs: int) -> flyte.io.File:
    dataset_uri = data_prep()
    return model_train_task(epochs=epochs, dataset_uri=dataset_uri)
```

### Build a custom connector image

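Build a custom image when you're ready to deploy your connector to your cluster.
The following script builds an image for the BigQuery connector. To package your own connector, swap the package passed to `with_pip_packages` for your connector's package: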

```python
import asyncio
from flyte import Image
from flyte.extend import ImageBuildEngine

async def build_flyte_connector_bigquery_image(registry: str, name: str, builder: str = "local"):
    """
    Build the SDK default connector image optionally overriding
    the container registry and image name.

    Args:
        registry: e.g. "ghcr.io/my-org" or "123456789012.dkr.ecr.us-west-2.amazonaws.com".
        name:     e.g. "my-connector".
        builder:  e.g. "local" or "remote".
    """

    default_image = Image.from_debian_base(
        registry=registry, name=name
    ).with_pip_packages("flyteplugins-bigquery", pre=True)
    await ImageBuildEngine.build(default_image, builder=builder)

if __name__ == "__main__":
    print("Building connector image...")
    asyncio.run(
        build_flyte_connector_bigquery_image(
            registry="<YOUR_REGISTRY>", name="flyte-bigquery", builder="local"
        )
    )
```

## LLM Serving

LLM serving integrations let you deploy and serve large language models as Flyte apps with an OpenAI-compatible API. They handle model loading, GPU management, and autoscaling.

### Supported LLM serving integrations

| Plugin                                                            | Description                                           | Common use cases                     |
| ----------------------------------------------------------------- | ----------------------------------------------------- | ------------------------------------ |
| [SGLang](https://www.union.ai/docs/v2/union/user-guide/build-apps/sglang-app/page.md)  | Deploy models with SGLang's high-throughput runtime   | LLM inference, model serving         |
| [vLLM](https://www.union.ai/docs/v2/union/user-guide/build-apps/vllm-app/page.md)      | Deploy models with vLLM's PagedAttention engine       | LLM inference, model serving         |

For full setup instructions including multi-GPU deployment, model prefetching, and autoscaling, see the [SGLang app](https://www.union.ai/docs/v2/union/user-guide/build-apps/sglang-app/page.md) and [vLLM app](https://www.union.ai/docs/v2/union/user-guide/build-apps/vllm-app/page.md) pages.
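
Because both runtimes expose an OpenAI-compatible API, any HTTP client can query a deployed model. The sketch below only builds a chat-completions request with the Python standard library. The base URL, model name, and bearer-token header are placeholders; the app's real endpoint and authentication depend on your deployment.

```python
import json
import urllib.request


def build_chat_request(base_url: str, model: str, prompt: str, api_key: str = "") -> urllib.request.Request:
    # OpenAI-compatible servers expose POST /v1/chat/completions
    body = {
        "model": model,
        "messages": [{"role": "user", "content": prompt}],
        "max_tokens": 128,
    }
    headers = {"Content-Type": "application/json"}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"  # placeholder auth scheme
    return urllib.request.Request(
        url=f"{base_url.rstrip('/')}/v1/chat/completions",
        data=json.dumps(body).encode("utf-8"),
        headers=headers,
        method="POST",
    )


# Placeholder URL and model name; substitute your deployed app's values
req = build_chat_request("https://my-llm-app.example.com", "my-model", "Hello!")
print(req.full_url)  # https://my-llm-app.example.com/v1/chat/completions

# To send it against a live app:
#   with urllib.request.urlopen(req) as resp:
#       reply = json.load(resp)["choices"][0]["message"]["content"]
```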

=== PAGE: https://www.union.ai/docs/v2/union/integrations/anthropic ===

# Anthropic

The Anthropic plugin lets you build agentic workflows with [Claude](https://www.anthropic.com/) on Flyte. It provides a `function_tool` decorator that wraps Flyte tasks as tools that Claude can call, and a `run_agent` function that drives the agent conversation loop.

When Claude calls a tool, the call executes as a Flyte task with full observability, retries, and caching.

## Installation

```bash
pip install flyteplugins-anthropic
```

Requires `anthropic >= 0.40.0`.

## Quick start

```python
import flyte
from flyteplugins.anthropic import function_tool, run_agent

env = flyte.TaskEnvironment(
    name="claude-agent",
    resources=flyte.Resources(cpu=1, memory="250Mi"),
    image=flyte.Image.from_uv_script(__file__, name="anthropic_agent"),
    secrets=flyte.Secret("anthropic_api_key", as_env_var="ANTHROPIC_API_KEY"),
)

@function_tool
@env.task
async def get_weather(city: str) -> str:
    """Get the current weather for a city."""
    return f"The weather in {city} is sunny, 72F"

@env.task
async def main(prompt: str) -> str:
    tools = [get_weather]
    return await run_agent(prompt=prompt, tools=tools)
```

## API

### `function_tool`

Converts a Flyte task, `@flyte.trace`-decorated function, or plain callable into a tool that Claude can invoke.

```python
@function_tool
@env.task
async def my_tool(param: str) -> str:
    """Tool description sent to Claude."""
    ...
```

Can also be called with optional overrides:

```python
@function_tool(name="custom_name", description="Custom description")
@env.task
async def my_tool(param: str) -> str:
    ...
```

Parameters:

| Parameter | Type | Description |
|-----------|------|-------------|
| `func` | callable | The function to wrap |
| `name` | `str` | Override the tool name (defaults to the function name) |
| `description` | `str` | Override the tool description (defaults to the docstring) |

> [!NOTE]
> The docstring on each `@function_tool` task is sent to Claude as the tool description. Write clear, concise docstrings.
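
Claude's Messages API describes each tool with a `name`, a `description`, and a JSON Schema `input_schema`. The sketch below shows one plausible way to derive that definition from a function's signature and docstring with `inspect`. `tool_schema` is a hypothetical helper, and `function_tool` is not necessarily implemented this way.

```python
import inspect
from typing import Any, Callable, Dict, Optional

# Illustrative subset of Python-to-JSON-Schema type mappings
_JSON_TYPES = {str: "string", int: "integer", float: "number", bool: "boolean"}


def tool_schema(func: Callable[..., Any], name: Optional[str] = None, description: Optional[str] = None) -> Dict[str, Any]:
    properties: Dict[str, Any] = {}
    required = []
    for pname, param in inspect.signature(func).parameters.items():
        # Unknown or missing annotations fall back to "string"
        properties[pname] = {"type": _JSON_TYPES.get(param.annotation, "string")}
        if param.default is inspect.Parameter.empty:
            required.append(pname)  # parameters without defaults are required
    return {
        "name": name or func.__name__,
        "description": description or inspect.getdoc(func) or "",
        "input_schema": {"type": "object", "properties": properties, "required": required},
    }


async def get_weather(city: str) -> str:
    """Get the current weather for a city."""
    return f"The weather in {city} is sunny, 72F"


print(tool_schema(get_weather))
```

This is why the docstring matters: in a scheme like this, it is the only natural-language signal the model receives about what the tool does.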

### `Agent`

A dataclass for bundling agent configuration:

```python
from flyteplugins.anthropic import Agent

agent = Agent(
    name="my-agent",
    instructions="You are a helpful assistant.",
    model="claude-sonnet-4-20250514",
    tools=[get_weather],
    max_tokens=4096,
    max_iterations=10,
)
```

| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `name` | `str` | `"assistant"` | Agent name |
| `instructions` | `str` | `"You are a helpful assistant."` | System prompt |
| `model` | `str` | `"claude-sonnet-4-20250514"` | Claude model ID |
| `tools` | `list[FunctionTool]` | `[]` | Tools available to the agent |
| `max_tokens` | `int` | `4096` | Maximum tokens per response |
| `max_iterations` | `int` | `10` | Maximum tool-call loop iterations |

### `run_agent`

Runs a Claude conversation loop, dispatching tool calls to Flyte tasks until Claude returns a final response.

```python
result = await run_agent(
    prompt="What's the weather in Tokyo?",
    tools=[get_weather],
    model="claude-sonnet-4-20250514",
)
```

You can also pass an `Agent` object:

```python
result = await run_agent(prompt="What's the weather?", agent=agent)
```

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `prompt` | `str` | required | User message |
| `tools` | `list[FunctionTool]` | `None` | Tools available to the agent |
| `agent` | `Agent` | `None` | Agent config (overrides individual params) |
| `model` | `str` | `"claude-sonnet-4-20250514"` | Claude model ID |
| `system` | `str` | `None` | System prompt |
| `max_tokens` | `int` | `4096` | Maximum tokens per response |
| `max_iterations` | `int` | `10` | Maximum iterations (prevents infinite loops) |
| `api_key` | `str` | `None` | API key (falls back to `ANTHROPIC_API_KEY` env var) |
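
Conceptually, the loop alternates between calling the model and executing requested tools until the model stops asking for them or `max_iterations` is reached. The sketch below reproduces that control flow with a scripted fake model so it runs offline. Message shapes follow the Anthropic Messages API (`tool_use` blocks answered with `tool_result` blocks). `agent_loop` and `fake_model` are hypothetical, and in the plugin each tool call runs as a Flyte task rather than a plain coroutine.

```python
import asyncio
from typing import Any, Awaitable, Callable, Dict, List

Message = Dict[str, Any]
ModelFn = Callable[[List[Message]], Awaitable[Dict[str, Any]]]
ToolFn = Callable[..., Awaitable[str]]


async def agent_loop(prompt: str, call_model: ModelFn, tools: Dict[str, ToolFn], max_iterations: int = 10) -> str:
    messages: List[Message] = [{"role": "user", "content": prompt}]
    for _ in range(max_iterations):
        response = await call_model(messages)
        messages.append({"role": "assistant", "content": response["content"]})
        if response["stop_reason"] != "tool_use":
            # Final answer: join the text blocks
            return "".join(b["text"] for b in response["content"] if b["type"] == "text")
        # Run every requested tool and send the results back in one user turn
        results = []
        for block in response["content"]:
            if block["type"] == "tool_use":
                output = await tools[block["name"]](**block["input"])
                results.append({"type": "tool_result", "tool_use_id": block["id"], "content": output})
        messages.append({"role": "user", "content": results})
    raise RuntimeError(f"Agent did not finish within {max_iterations} iterations")


async def fake_model(messages: List[Message]) -> Dict[str, Any]:
    # Scripted stand-in for Claude: request a tool once, then answer
    if len(messages) == 1:
        return {
            "stop_reason": "tool_use",
            "content": [{"type": "tool_use", "id": "call_1", "name": "get_weather", "input": {"city": "Tokyo"}}],
        }
    tool_output = messages[-1]["content"][0]["content"]
    return {"stop_reason": "end_turn", "content": [{"type": "text", "text": f"Answer: {tool_output}"}]}


async def get_weather(city: str) -> str:
    return f"The weather in {city} is sunny, 72F"


print(asyncio.run(agent_loop("What's the weather in Tokyo?", fake_model, {"get_weather": get_weather})))
# Answer: The weather in Tokyo is sunny, 72F
```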

## Secrets

Store your Anthropic API key as a Flyte secret and expose it as an environment variable:

```python
secrets=flyte.Secret("anthropic_api_key", as_env_var="ANTHROPIC_API_KEY")
```

## API reference

See the [Anthropic API reference](https://www.union.ai/docs/v2/union/api-reference/integrations/anthropic/_index) for full details.

=== PAGE: https://www.union.ai/docs/v2/union/integrations/bigquery ===

# BigQuery

The BigQuery connector lets you run SQL queries against [Google BigQuery](https://cloud.google.com/bigquery) directly from Flyte tasks. Queries are submitted asynchronously via the BigQuery Jobs API and polled for completion, so they don't block a worker while waiting for results.

The connector supports:

- Parameterized SQL queries with typed inputs
- Google Cloud service account authentication
- Query results returned as DataFrames
- Query cancellation on task abort

## Installation

```bash
pip install flyteplugins-bigquery
```

This installs the Google Cloud BigQuery client libraries.

## Quick start

Here's a minimal example that runs a SQL query on BigQuery:

```python
from flyte.io import DataFrame
from flyteplugins.bigquery import BigQueryConfig, BigQueryTask

config = BigQueryConfig(
    ProjectID="my-gcp-project",
    Location="US",
)

count_users = BigQueryTask(
    name="count_users",
    query_template="SELECT COUNT(*) FROM dataset.users",
    plugin_config=config,
    output_dataframe_type=DataFrame,
)
```

This defines a task called `count_users` that runs the query in the configured GCP project. When executed, the connector:

1. Connects to BigQuery using the provided configuration
2. Submits the query asynchronously via the Jobs API
3. Polls until the query completes or fails

To run the task, create a `TaskEnvironment` from it and execute it locally or remotely:

```python
import flyte

bigquery_env = flyte.TaskEnvironment.from_task("bigquery_env", count_users)

if __name__ == "__main__":
    flyte.init_from_config()

    # mode="remote" runs the connector on the control plane; mode="local" runs it
    # in-process and requires flyteplugins-bigquery and credentials on your machine.
    run = flyte.with_runcontext(mode="remote").run(count_users)

    print(run.url)
```

> [!NOTE]
> The `TaskEnvironment` created by `from_task` does not need an image or pip packages. BigQuery tasks are connector tasks, which means the query executes on the connector service, not in your task container. In `local` mode, the connector runs in-process and requires `flyteplugins-bigquery` and credentials to be available on your machine.

## Configuration

### `BigQueryConfig` parameters

| Field | Type | Required | Description |
|-------|------|----------|-------------|
| `ProjectID` | `str` | Yes | GCP project ID |
| `Location` | `str` | No | BigQuery region (e.g., `"US"`, `"EU"`) |
| `QueryJobConfig` | `bigquery.QueryJobConfig` | No | Native BigQuery [QueryJobConfig](https://cloud.google.com/python/docs/reference/bigquery/latest/google.cloud.bigquery.job.QueryJobConfig) object for advanced settings |

### `BigQueryTask` parameters

| Parameter | Type | Description |
|-----------|------|-------------|
| `name` | `str` | Unique task name |
| `query_template` | `str` | SQL query (whitespace is normalized before execution) |
| `plugin_config` | `BigQueryConfig` | Connection configuration |
| `inputs` | `Dict[str, Type]` | Named typed inputs bound as query parameters |
| `output_dataframe_type` | `Type[DataFrame]` | If set, query results are returned as a `DataFrame` |
| `google_application_credentials` | `str` | Name of the Flyte secret containing the GCP service account JSON key |

## Authentication

Pass the name of a Flyte secret containing your GCP service account JSON key:

```python
query = BigQueryTask(
    name="secure_query",
    query_template="SELECT * FROM dataset.sensitive_data",
    plugin_config=config,
    google_application_credentials="my-gcp-sa-key",
)
```

## Query templating

Use the `inputs` parameter to define typed inputs for your query. Input values are bound as BigQuery `ScalarQueryParameter` values.

### Supported input types

| Python type | BigQuery type |
|-------------|---------------|
| `int` | `INT64` |
| `float` | `FLOAT64` |
| `str` | `STRING` |
| `bool` | `BOOL` |
| `bytes` | `BYTES` |
| `datetime` | `DATETIME` |
| `list` | `ARRAY` |
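
The mapping in the table can be expressed as a lookup on the declared input type. This sketch is illustrative only: `bigquery_type` is a hypothetical helper, not part of `flyteplugins-bigquery`. It resolves generics such as `list[str]` to `list` before the lookup.

```python
import datetime
import typing
from typing import Any, Dict

# Mirrors the table above; illustrative only
_BIGQUERY_TYPES: Dict[type, str] = {
    int: "INT64",
    float: "FLOAT64",
    str: "STRING",
    bool: "BOOL",
    bytes: "BYTES",
    datetime.datetime: "DATETIME",
    list: "ARRAY",
}


def bigquery_type(py_type: Any) -> str:
    # Parameterized generics such as list[str] resolve to their origin (list)
    base = typing.get_origin(py_type) or py_type
    try:
        return _BIGQUERY_TYPES[base]
    except KeyError:
        raise TypeError(f"Unsupported input type: {py_type!r}") from None


inputs = {"region": str, "min_score": float}
print({name: bigquery_type(t) for name, t in inputs.items()})
# {'region': 'STRING', 'min_score': 'FLOAT64'}
```

Looking up the exact declared type (rather than calling `isinstance` on values) keeps `bool` distinct from `int`, even though `bool` is a subclass of `int` in Python.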

### Parameterized query example

```python
from flyte.io import DataFrame

events_by_region = BigQueryTask(
    name="events_by_region",
    query_template="SELECT * FROM dataset.events WHERE region = @region AND score > @min_score",
    plugin_config=config,
    inputs={"region": str, "min_score": float},
    output_dataframe_type=DataFrame,
)
```

> [!NOTE]
> The query template is normalized before execution: newlines and tabs are replaced with spaces and consecutive whitespace is collapsed. You can format your queries across multiple lines for readability without affecting execution.
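
A single regular expression is enough to sketch a normalization with the described behavior. This is an assumption about the approach, not the plugin's code. Unlike the description above, this sketch also trims leading and trailing whitespace.

```python
import re


def normalize_query(query: str) -> str:
    # Replace every run of whitespace (spaces, tabs, newlines) with one space
    return re.sub(r"\s+", " ", query).strip()


query = """
    SELECT customer_id, SUM(amount) AS total_spend
    FROM dataset.orders
    GROUP BY customer_id
"""
print(normalize_query(query))
# SELECT customer_id, SUM(amount) AS total_spend FROM dataset.orders GROUP BY customer_id
```

Because newlines are removed, a `--` line comment in a template normalized this way would comment out everything after it. `/* ... */` comments are safer in multi-line templates. Runs of spaces inside string literals are collapsed too.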

## Retrieving query results

Set `output_dataframe_type` to capture results as a DataFrame:

```python
from flyte.io import DataFrame

top_customers = BigQueryTask(
    name="top_customers",
    query_template="""
        SELECT customer_id, SUM(amount) AS total_spend
        FROM dataset.orders
        GROUP BY customer_id
        ORDER BY total_spend DESC
        LIMIT 100
    """,
    plugin_config=config,
    output_dataframe_type=DataFrame,
)
```

If you don't need query results (for example, DDL statements or INSERT queries), omit `output_dataframe_type`.

## API reference

See the [BigQuery API reference](https://www.union.ai/docs/v2/union/api-reference/integrations/bigquery/_index) for full details.

=== PAGE: https://www.union.ai/docs/v2/union/integrations/codegen ===

# Code generation

The code generation plugin turns natural-language prompts into tested, production-ready Python code.

You describe what the code should do and supply sample data, schema definitions, constraints, and typed inputs and outputs. The plugin handles the rest: generating code, writing tests, building an isolated [code sandbox](/docs/v2/union/user-guide/sandboxing/code-sandboxing) with the right dependencies, running the tests, diagnosing failures, and iterating until everything passes. The result is a validated script you can execute against real data or deploy as a reusable Flyte task.

## Installation

```bash
pip install flyteplugins-codegen

# For Agent mode (Claude-only)
pip install "flyteplugins-codegen[agent]"
```

## Quick start

```python{hl_lines=[3, 4, 6, 12, 14, "20-25"]}
import flyte
from flyte.io import File
from flyte.sandbox import sandbox_environment
from flyteplugins.codegen import AutoCoderAgent

agent = AutoCoderAgent(model="gpt-4.1", name="summarize-sales")

env = flyte.TaskEnvironment(
    name="my-env",
    secrets=[flyte.Secret(key="openai_key", as_env_var="OPENAI_API_KEY")],
    image=flyte.Image.from_debian_base().with_pip_packages(
        "flyteplugins-codegen",
    ),
    depends_on=[sandbox_environment],
)

@env.task
async def process_data(csv_file: File) -> tuple[float, int, int]:
    result = await agent.generate.aio(
        prompt="Read the CSV and compute total_revenue, total_units and row_count.",
        samples={"sales": csv_file},
        outputs={"total_revenue": float, "total_units": int, "row_count": int},
    )
    return await result.run.aio()
```

The `depends_on=[sandbox_environment]` declaration is required. It ensures the sandbox runtime is available when dynamically-created sandboxes execute.

![Sandbox](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/integrations/codegen/sandbox.png)

## Two execution backends

The plugin supports two backends for generating and validating code. Both share the same `AutoCoderAgent` interface and produce the same `CodeGenEvalResult`.

### LiteLLM (default)

Uses structured-output LLM calls to generate code, detect packages, build sandbox images, run tests, diagnose failures, and iterate. Works with any model that supports structured outputs (GPT-4, Claude, Gemini, etc. via LiteLLM).

```python{hl_lines=[1, 3]}
agent = AutoCoderAgent(
    name="my-task",
    model="gpt-4.1",
    max_iterations=10,
)
```

The LiteLLM backend follows a fixed pipeline:

```mermaid
flowchart TD
    A["prompt + samples"] --> B["generate_plan"]
    B --> C["generate_code"]
    C --> D["detect_packages"]
    D --> E["build_image"]
    E --> F{skip_tests?}
    F -- yes --> G["return result"]
    F -- no --> H["generate_tests"]
    H --> I["execute_tests"]
    I --> J{pass?}
    J -- yes --> G
    J -- no --> K["diagnose_error"]
    K --> L{error type?}
    L -- "logic error" --> M["regenerate code"]
    L -- "environment error" --> N["add packages, rebuild image"]
    L -- "test error" --> O["fix test expectations"]
    M --> I
    N --> I
    O --> I
```

The loop continues until tests pass or `max_iterations` is reached.
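To make the control flow of the diagram concrete, here is a minimal, self-contained sketch of the test-and-fix loop alone (the plan, code, package and image steps that precede it are omitted). `RunOutcome`, `LoopState`, `generate_test_fix` and `fake_tests` are hypothetical names invented for illustration; they are not the plugin's API, and the real backend calls an LLM at each step instead of incrementing counters.

```python
from dataclasses import dataclass, field
from typing import Callable, Optional


# Hypothetical types for illustration; they are not part of flyteplugins-codegen.
@dataclass
class RunOutcome:
    passed: bool
    error_type: Optional[str] = None  # "logic", "environment" or "test_error"
    missing_package: Optional[str] = None


@dataclass
class LoopState:
    code_version: int = 0
    test_version: int = 0
    packages: list = field(default_factory=list)
    attempts: int = 0


def generate_test_fix(
    run_tests: Callable[[LoopState], RunOutcome], max_iterations: int
) -> tuple:
    """Run tests, then apply one targeted fix per failed iteration."""
    state = LoopState()
    for _ in range(max_iterations):
        state.attempts += 1
        outcome = run_tests(state)
        if outcome.passed:
            return True, state
        if outcome.error_type == "logic":
            state.code_version += 1  # regenerate the code
        elif outcome.error_type == "environment":
            state.packages.append(outcome.missing_package)  # add package, rebuild image
        else:
            state.test_version += 1  # fix the test expectations
    return False, state  # budget exhausted, analogous to result.success == False


# Usage: a fake test runner that needs two code regenerations before passing.
def fake_tests(state: LoopState) -> RunOutcome:
    if state.code_version < 2:
        return RunOutcome(passed=False, error_type="logic")
    return RunOutcome(passed=True)


ok, final = generate_test_fix(fake_tests, max_iterations=10)
print(ok, final.attempts)  # True 3
```

With a smaller budget (`max_iterations=2`) the same runner never reaches a passing state, which mirrors how a too-tight `max_iterations` surfaces as a failed result.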

![LiteLLM](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/integrations/codegen/litellm.png)

### Agent (Claude)

Uses the Claude Agent SDK to autonomously generate, test, and fix code. The agent has access to `Bash`, `Read`, `Write`, and `Edit` tools and decides what to do at each step. Test execution commands (`pytest`) are intercepted and run inside isolated sandboxes.

```python{hl_lines=["3-4"]}
agent = AutoCoderAgent(
    name="my-task",
    model="claude-sonnet-4-5-20250929",
    backend="claude",
)
```

> [!NOTE]
> Agent mode requires `ANTHROPIC_API_KEY` as a Flyte secret and is Claude-only.

**Key differences from LiteLLM:**

|                       | LiteLLM                           | Agent                                          |
| --------------------- | --------------------------------- | ---------------------------------------------- |
| **Execution**         | Fixed generate-test-fix pipeline  | Autonomous agent decides actions               |
| **Model support**     | Any model with structured outputs | Claude only                                    |
| **Iteration control** | `max_iterations`                  | `agent_max_turns`                              |
| **Test execution**    | Direct sandbox execution          | `pytest` commands intercepted via hooks        |
| **Tool safety**       | N/A                               | Commands classified as safe/denied/intercepted |
| **Observability**     | Logs + token counts               | Full tool call tracing in Flyte UI             |

In Agent mode, Bash commands are classified before execution:

- **Safe** (`ls`, `cat`, `grep`, `head`, etc.) — allowed to run directly
- **Intercepted** (`pytest`) — routed to sandbox execution
- **Denied** (`apt`, `pip install`, `curl`, etc.) — blocked for safety
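A rough sketch of how such a classifier could work, assuming the first token of the command decides the category. The sets below contain only the examples named above, unknown commands are denied (a default-deny policy that is an assumption, not documented behavior), and pipelines or chained commands are not handled. `classify` is a hypothetical helper, not part of the plugin.

```python
import shlex

# Illustrative lists built only from the examples above; not the plugin's real lists.
SAFE = {"ls", "cat", "grep", "head"}
INTERCEPTED = {"pytest"}
DENIED = {"apt", "curl"}


def classify(command: str) -> str:
    """Return "safe", "intercepted" or "denied" for a Bash command string."""
    tokens = shlex.split(command)
    if not tokens:
        return "denied"
    program = tokens[0]
    if program in INTERCEPTED:
        return "intercepted"  # routed to a sandbox instead of running locally
    if program in DENIED or (program == "pip" and "install" in tokens[1:]):
        return "denied"
    if program in SAFE:
        return "safe"
    return "denied"  # assumption: anything not explicitly allowed is blocked


print(classify("pytest tests/"))  # intercepted
```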

## Providing data

### Sample data

Pass sample data via `samples` as `File` objects or pandas `DataFrame`s. The plugin automatically:

1. Converts DataFrames to CSV files
2. Infers [Pandera](https://pandera.readthedocs.io/) schemas from the data — column types, nullability
3. Parses natural-language `constraints` into Pandera checks (e.g., `"quantity must be positive"` becomes `pa.Check.gt(0)`)
4. Extracts data context — column statistics, distributions, patterns, sample rows
5. Injects all of this into the LLM prompt so the generated code is aware of the exact data structure

Pandera is used purely for prompt enrichment, not runtime validation. The generated code does not import Pandera — it benefits from the LLM knowing the precise data structure. The generated schemas are stored on `result.generated_schemas` for inspection.

```python{hl_lines=[3]}
result = await agent.generate.aio(
    prompt="Clean and validate the data, remove duplicates",
    samples={"orders": orders_df, "products": products_file},
    constraints=["quantity must be positive", "price between 0 and 10000"],
    outputs={"cleaned_orders": File},
)
```
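The plugin builds Pandera schemas; the stdlib-only sketch below only illustrates the kind of structural facts such inference can extract from a sample (a column type and whether it contains nulls). It does not use Pandera, and `infer_schema` and its output format are hypothetical.

```python
import csv
import io


def _infer_dtype(values: list) -> str:
    """Return the narrowest of "int", "float" or "str" that parses every value."""
    for caster, name in ((int, "int"), (float, "float")):
        try:
            for value in values:
                caster(value)
            return name
        except (ValueError, TypeError):
            continue
    return "str"


def infer_schema(csv_text: str) -> dict:
    """Infer a dtype and a nullability flag for each column of a CSV sample."""
    reader = csv.DictReader(io.StringIO(csv_text))
    rows = list(reader)
    schema = {}
    for column in reader.fieldnames or []:
        values = [row[column] for row in rows]
        present = [v for v in values if v not in ("", None)]  # empty cells count as nulls
        schema[column] = {
            "dtype": _infer_dtype(present),
            "nullable": len(present) < len(values),
        }
    return schema


sample = "order_id,quantity,price,note\n1,3,9.99,\n2,5,12.50,rush\n"
print(infer_schema(sample)["price"])  # {'dtype': 'float', 'nullable': False}
```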

### Schema and constraints

Use `schema` to provide free-form context about data formats or target structures (e.g., a database schema). Use `constraints` to declare business rules that the generated code must respect:

```python{hl_lines=["4-17"]}
result = await agent.generate.aio(
    prompt=prompt,
    samples={"readings": sensor_df},
    schema="""Output JSON schema for report_json:
    {
        "sensor_id": str,
        "avg_temp": float,
        "min_temp": float,
        "max_temp": float,
        "avg_humidity": float,
    }
    """,
    constraints=[
        "Temperature values must be between -40 and 60 Celsius",
        "Humidity values must be between 0 and 100 percent",
        "Output report must have one row per unique sensor_id",
    ],
    outputs={
        "report_json": str,
        "total_anomalies": int,
    },
)
```
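The plugin performs the constraint-to-check translation itself. To illustrate the idea only, this sketch maps two common phrasings (including the `"quantity must be positive"` example from earlier on this page) to check descriptors that loosely mirror `pa.Check.gt` and `pa.Check.in_range` with inclusive bounds. The regexes, `parse_constraint` and `satisfies` are assumptions made for the example; a phrasing such as `"Output report must have one row per unique sensor_id"` is simply rejected here.

```python
import re

_POSITIVE = re.compile(r"(\w+)(?: values)? must be positive")
_RANGE = re.compile(
    r"(\w+)(?: values)? (?:must be )?between (-?[\d.]+) and (-?[\d.]+)(?: \w+)?"
)


def parse_constraint(text: str) -> dict:
    """Translate a few constraint phrasings into check descriptors."""
    text = text.strip().lower()
    if m := _POSITIVE.fullmatch(text):
        return {"column": m.group(1), "check": "gt", "value": 0}
    if m := _RANGE.fullmatch(text):
        return {"column": m.group(1), "check": "in_range",
                "min": float(m.group(2)), "max": float(m.group(3))}
    raise ValueError(f"unsupported constraint: {text!r}")


def satisfies(row: dict, check: dict) -> bool:
    """Evaluate one descriptor against a row (range bounds are inclusive)."""
    value = float(row[check["column"]])
    if check["check"] == "gt":
        return value > check["value"]
    return check["min"] <= value <= check["max"]


rule = parse_constraint("Temperature values must be between -40 and 60 Celsius")
print(rule)
print(satisfies({"temperature": "72.5"}, rule))  # False
```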

![Pandera Constraints](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/integrations/codegen/pandera_constraints.png)

### Inputs and outputs

Declare `inputs` for non-sample arguments (e.g., thresholds, flags) and `outputs` for the expected result types.

Supported output types: `str`, `int`, `float`, `bool`, `datetime.datetime`, `datetime.timedelta`, `File`.

Sample entries are automatically added as `File` inputs — you do not need to redeclare them.

```python{hl_lines=[4, 5]}
result = await agent.generate.aio(
    prompt="Filter transactions above the threshold",
    samples={"transactions": tx_file},
    inputs={"threshold": float, "include_pending": bool},
    outputs={"filtered": File, "count": int},
)
```
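A minimal sketch of how the declared names could combine into the generated code's signature, assuming sample names come first and declared inputs follow. `File` here is a local stand-in for `flyte.io.File` so the sketch runs without Flyte, and `build_signature` is hypothetical; the ordering and validation shown are illustrative, not the plugin's implementation.

```python
import datetime


class File:
    """Stand-in for flyte.io.File so this sketch runs without Flyte installed."""


SUPPORTED_OUTPUT_TYPES = {
    str, int, float, bool, datetime.datetime, datetime.timedelta, File,
}


def build_signature(samples: dict, inputs: dict, outputs: dict) -> tuple:
    """Combine sample names (as File inputs) with declared inputs; check outputs."""
    for name, output_type in outputs.items():
        if output_type not in SUPPORTED_OUTPUT_TYPES:
            raise TypeError(f"unsupported output type for {name!r}: {output_type!r}")
    all_inputs = {name: File for name in samples}  # samples need no redeclaration
    all_inputs.update(inputs)
    return all_inputs, dict(outputs)


ins, outs = build_signature(
    samples={"transactions": "tx.csv"},
    inputs={"threshold": float, "include_pending": bool},
    outputs={"filtered": File, "count": int},
)
print(list(ins))  # ['transactions', 'threshold', 'include_pending']
```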

## Running generated code

`agent.generate()` returns a `CodeGenEvalResult`. If `result.success` is `True`, the generated code passed all tests and you can execute it against real data. If `max_iterations` (LiteLLM) or `agent_max_turns` (Agent) is reached without tests passing, `result.success` is `False` and `result.error` contains the failure details.

Both `run()` and `as_task()` return output values as a tuple in the order declared in `outputs`. If there is a single output, the value is returned directly (not wrapped in a tuple).

### One-shot execution with `result.run()`

Runs the generated code in a sandbox. If samples were provided during `generate()`, they are used as default inputs.

```python
# Use sample data as defaults
total_revenue, total_units, count = await result.run.aio()

# Override specific inputs (assumes a `threshold` input was declared via `inputs`)
total_revenue, total_units, count = await result.run.aio(threshold=0.5)

# Sync version
total_revenue, total_units, count = result.run()
```
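The two calling conventions described in this section (sample data as default inputs, and results returned in `outputs` order with a single output unwrapped) can be sketched in plain Python. `resolve_inputs` and `pack_outputs` are hypothetical helpers written for this illustration, not methods of `CodeGenEvalResult`.

```python
from typing import Any


def resolve_inputs(sample_defaults: dict, overrides: dict) -> dict:
    """Sample data supplies defaults; keyword overrides take precedence."""
    return {**sample_defaults, **overrides}


def pack_outputs(declared: list, values: dict) -> Any:
    """Return values in the order declared in `outputs`; unwrap a single output."""
    ordered = tuple(values[name] for name in declared)
    return ordered[0] if len(ordered) == 1 else ordered


inputs = resolve_inputs({"sales": "sales.csv", "threshold": 0.1}, {"threshold": 0.5})
revenue, units, rows = pack_outputs(
    ["total_revenue", "total_units", "row_count"],
    {"row_count": 3, "total_revenue": 10.5, "total_units": 7},
)
count = pack_outputs(["count"], {"count": 4})  # 4, not (4,)
```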

`result.run()` accepts optional configuration:

```python{hl_lines=["4-6"]}
total_revenue, total_units, count = await result.run.aio(
    name="execute-on-data",
    resources=flyte.Resources(cpu=2, memory="4Gi"),
    retries=2,
    timeout=600,
    cache="auto",
)
```

### Reusable task with `result.as_task()`

Creates a callable sandbox task from the generated code. Useful when you want to run the same generated code against different data.

```python{hl_lines=[1, "6-7", "9-10"]}
task = result.as_task(
    name="run-sensor-analysis",
    resources=flyte.Resources(cpu=1, memory="512Mi"),
)

# Call with sample defaults
report, total_anomalies = await task.aio()

# Call with different data
report, total_anomalies = await task.aio(readings=new_data_file)
```

## Error diagnosis

The LiteLLM backend classifies test failures into three categories and applies targeted fixes:

| Error type    | Meaning                       | Action                                           |
| ------------- | ----------------------------- | ------------------------------------------------ |
| `logic`       | Bug in the generated code     | Regenerate code with specific patch instructions |
| `environment` | Missing package or dependency | Add the package and rebuild the sandbox image    |
| `test_error`  | Bug in the generated test     | Fix the test expectations                        |

If the same error persists after a fix, the plugin reclassifies it (e.g., `logic` to `test_error`) to try the other approach.
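A sketch of that reclassification rule, assuming "the same error" means an identical error message right after a fix of the same type. `choose_fix` and the `ALTERNATIVE` table are hypothetical; in particular, leaving `environment` errors unchanged is an assumption, since the text only gives `logic` to `test_error` as an example.

```python
# Hypothetical fallback table: which other fix to try if the first one did not help.
ALTERNATIVE = {"logic": "test_error", "test_error": "logic"}


def choose_fix(error_type: str, error_message: str, history: list) -> str:
    """Pick the fix for this failure.

    history holds (fix_type_applied, error_message) pairs from earlier iterations.
    """
    if history:
        last_fix, last_message = history[-1]
        if last_fix == error_type and last_message == error_message:
            # Same diagnosis, same error: the fix did not work, so switch approach.
            return ALTERNATIVE.get(error_type, error_type)
    return error_type


print(choose_fix("logic", "AssertionError: 3 != 4", [("logic", "AssertionError: 3 != 4")]))  # test_error
```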

In Agent mode, the agent diagnoses and fixes issues autonomously based on error output.

## Durable execution

Code generation is expensive — it involves multiple LLM calls, image builds, and sandbox executions. Without durability, a transient failure in the pipeline (network blip, OOM, downstream service error) would force the entire process to restart from scratch: regenerating code, rebuilding images, re-running sandboxes, making additional LLM calls.

Flyte solves this through two complementary mechanisms: **replay logs** and **caching**.

### Replay logs

Flyte maintains a replay log that records every trace and task execution within a run. When a task crashes and retries, the system replays the log from the previous attempt rather than recomputing everything:

- No additional model calls
- No code regeneration
- No sandbox re-execution
- No container rebuilds

The workflow breezes through the earlier steps and resumes from the failure point. This applies as long as the traces and tasks execute in the same order and use the same inputs as the first attempt.
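The replay idea can be illustrated with a toy in-memory log keyed by step order, step name and inputs. This is a conceptual sketch, not Flyte's implementation: `ReplayLog` and `pipeline` are invented names, and Flyte's replay log records traces and task executions on the platform rather than in a Python list.

```python
from typing import Any, Callable, List, Optional, Tuple

Record = Tuple[str, tuple, Any]


class ReplayLog:
    """Toy replay log: (step name, args) -> result, kept in execution order."""

    def __init__(self, recorded: Optional[List[Record]] = None):
        self.recorded: List[Record] = list(recorded or [])
        self.position = 0
        self.executed: List[str] = []  # steps that really ran in this attempt

    def step(self, name: str, fn: Callable[..., Any], *args: Any) -> Any:
        if self.position < len(self.recorded):
            rec_name, rec_args, rec_result = self.recorded[self.position]
            if (rec_name, rec_args) == (name, args):
                self.position += 1
                return rec_result  # replayed: no model call, no sandbox run
            self.recorded = self.recorded[: self.position]  # diverged: drop the rest
        result = fn(*args)
        self.recorded.append((name, args, result))
        self.position += 1
        self.executed.append(name)
        return result


def pipeline(log: ReplayLog, crash: bool) -> str:
    plan = log.step("generate_plan", lambda p: f"plan({p})", "sum sales")
    code = log.step("generate_code", lambda p: f"code({p})", plan)
    if crash:
        raise RuntimeError("transient failure")  # e.g. an OOM on the first attempt
    return log.step("execute_tests", lambda c: f"passed({c})", code)


first_attempt = ReplayLog()
try:
    pipeline(first_attempt, crash=True)
except RuntimeError:
    pass

retry = ReplayLog(first_attempt.recorded)
outcome = pipeline(retry, crash=False)
print(retry.executed)  # ['execute_tests']
```

The retry only executes the step that never completed; if the steps ran in a different order or with different inputs, the sketch (like the behavior described above) falls back to executing from the point of divergence.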

### Caching

Separately, Flyte can cache task results across runs. With `cache="auto"`, sandbox executions (image builds, test runs, code execution) are cached. This is useful when you re-run the same pipeline — not just when recovering from a crash, but across entirely separate invocations with the same inputs.

Together, replay logs handle crash recovery within a run, and caching avoids redundant work across runs.
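For contrast with the replay log, here is a toy cross-run cache that mimics the three `cache` modes listed in the API reference below. The key hashes only the task name and JSON-serializable inputs, which is a simplifying assumption (a real cache key would typically also depend on the task definition); `cache_key`, `run_cached` and `build_image` are hypothetical.

```python
import hashlib
import json
from typing import Any, Callable, Dict

_CACHE: Dict[str, Any] = {}


def cache_key(task_name: str, inputs: dict) -> str:
    """Hash the task name and its (JSON-serializable) inputs."""
    payload = json.dumps({"task": task_name, "inputs": inputs}, sort_keys=True)
    return hashlib.sha256(payload.encode()).hexdigest()


def run_cached(task_name: str, fn: Callable[..., Any], inputs: dict, cache: str = "auto") -> Any:
    if cache not in ("auto", "override", "disable"):
        raise ValueError(f"unknown cache mode: {cache!r}")
    key = cache_key(task_name, inputs)
    if cache == "auto" and key in _CACHE:
        return _CACHE[key]  # hit: skip the image build or sandbox run
    result = fn(**inputs)
    if cache != "disable":
        _CACHE[key] = result  # "auto" stores; "override" recomputes and replaces
    return result


calls = []


def build_image(packages: list) -> str:
    calls.append(packages)
    return "image:" + "+".join(packages)


run_cached("build_image", build_image, {"packages": ["pandas"]})
run_cached("build_image", build_image, {"packages": ["pandas"]})  # cache hit
print(len(calls))  # 1
```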

### Non-determinism in Agent mode

One challenge with agents is that they are inherently non-deterministic — the sequence of actions can vary between runs, which could break replay.

In practice, the codegen agent follows a predictable pattern (write code, generate tests, run tests, inspect results), which works in replay's favor. The plugin also embeds logic that instructs the agent not to regenerate or re-execute steps that already completed successfully in the first run. This acts as an additional safety check alongside the replay log to account for non-determinism.

![Agent](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/integrations/codegen/agent.png)

On the first attempt, the full pipeline runs. If a transient failure occurs, the system instantly replays the traces (which track model calls) and sandbox executions, allowing the pipeline to resume from the point of failure.

![Durability](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/integrations/codegen/durability.png)

## Observability

### LiteLLM backend

- Logs every iteration with attempt count, error type, and package changes
- Tracks total input/output tokens across all LLM calls (available on `result.total_input_tokens` and `result.total_output_tokens`)
- Results include full conversation history for debugging (`result.conversation_history`)

### Agent backend

- Traces each tool call (name + input) via `PostToolUse` hooks
- Traces tool failures via `PostToolUseFailure` hooks
- Traces a summary when the agent finishes (total tool calls, tool distribution, final image/packages)
- Classifies Bash commands as safe, denied, or intercepted (for sandbox execution)
- All traces appear in the Flyte UI
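A sketch of the kind of end-of-run summary described above, computed from a list of traced tool calls. The record shape (`name`, `input`, optional `failed`), the image placeholder, and `summarize_tool_calls` are assumptions for illustration, not the plugin's trace format.

```python
from collections import Counter
from typing import Optional


def summarize_tool_calls(calls: list, image: Optional[str] = None,
                         packages: Optional[list] = None) -> dict:
    """Build an end-of-run summary from traced tool calls."""
    distribution = Counter(call["name"] for call in calls)
    return {
        "total_tool_calls": len(calls),
        "tool_distribution": dict(distribution),
        "failed_tool_calls": sum(1 for call in calls if call.get("failed")),
        "final_image": image,
        "final_packages": list(packages or []),
    }


calls = [
    {"name": "Write", "input": {"file_path": "solution.py"}},
    {"name": "Bash", "input": {"command": "pytest"}},
    {"name": "Bash", "input": {"command": "pytest"}, "failed": True},
    {"name": "Edit", "input": {"file_path": "solution.py"}},
]
summary = summarize_tool_calls(calls, image="<image-uri>", packages=["pandas"])
print(summary["tool_distribution"])  # {'Write': 1, 'Bash': 2, 'Edit': 1}
```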

## Examples

### Processing CSVs with different schemas

Generate code that handles varying CSV formats, then run on real data:

```python{hl_lines=[3, 5, 16, 18, 28]}
import flyte
from flyte.io import File
from flyteplugins.codegen import AutoCoderAgent

agent = AutoCoderAgent(
    name="sales-processor",
    model="gpt-4.1",
    max_iterations=5,
    resources=flyte.Resources(cpu=1, memory="512Mi"),
    litellm_params={"temperature": 0.2, "max_tokens": 4096},
)

@env.task
async def process_sales(csv_file: File) -> dict[str, float | int]:
    result = await agent.generate.aio(
        prompt="Read the CSV and compute total_revenue, total_units, and transaction_count.",
        samples={"csv_data": csv_file},
        outputs={
            "total_revenue": float,
            "total_units": int,
            "transaction_count": int,
        },
    )

    if not result.success:
        raise RuntimeError(f"Code generation failed: {result.error}")

    total_revenue, total_units, transaction_count = await result.run.aio()

    return {
        "total_revenue": total_revenue,
        "total_units": total_units,
        "transaction_count": transaction_count,
    }
```

### DataFrame analysis with constraints

Pass DataFrames directly and enforce business rules with constraints:

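```python{hl_lines=[15, "19-23"]}
import flyte
import pandas as pd
from flyte.io import File
from flyteplugins.codegen import AutoCoderAgent

agent = AutoCoderAgent(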
    model="gpt-4.1",
    name="sensor-analysis",
    base_packages=["numpy"],
    max_sample_rows=30,
)

@env.task
async def analyze_sensors(sensor_df: pd.DataFrame) -> tuple[File, int]:
    result = await agent.generate.aio(
        prompt="""Analyze IoT sensor data. For each sensor, calculate mean/min/max
temperature, mean humidity, and count warnings. Output a summary CSV.""",
        samples={"readings": sensor_df},
        constraints=[
            "Temperature values must be between -40 and 60 Celsius",
            "Humidity values must be between 0 and 100 percent",
            "Output report must have one row per unique sensor_id",
        ],
        outputs={
            "report": File,
            "total_anomalies": int,
        },
    )

    if not result.success:
        raise RuntimeError(f"Code generation failed: {result.error}")

    task = result.as_task(
        name="run-sensor-analysis",
        resources=flyte.Resources(cpu=1, memory="512Mi"),
    )

    return await task.aio(readings=result.original_samples["readings"])
```

### Agent mode

The same task using Claude as an autonomous agent:

```python{hl_lines=[7]}
import flyte
from flyte.io import File
from flyteplugins.codegen import AutoCoderAgent

agent = AutoCoderAgent(
    name="sales-agent",
    backend="claude",
    model="claude-sonnet-4-5-20250929",
    resources=flyte.Resources(cpu=1, memory="512Mi"),
)

@env.task
async def process_sales_with_agent(csv_file: File) -> dict[str, float | int]:
    result = await agent.generate.aio(
        prompt="Read the CSV and compute total_revenue, total_units, and transaction_count.",
        samples={"csv_data": csv_file},
        outputs={
            "total_revenue": float,
            "total_units": int,
            "transaction_count": int,
        },
    )

    if not result.success:
        raise RuntimeError(f"Agent code generation failed: {result.error}")

    total_revenue, total_units, transaction_count = await result.run.aio()

    return {
        "total_revenue": total_revenue,
        "total_units": total_units,
        "transaction_count": transaction_count,
    }
```

## Configuration

### LiteLLM parameters

Tune model behavior with `litellm_params`:

```python{hl_lines=["5-8"]}
agent = AutoCoderAgent(
    name="my-task",
    model="anthropic/claude-sonnet-4-20250514",
    api_key="ANTHROPIC_API_KEY",
    litellm_params={
        "temperature": 0.3,
        "max_tokens": 4000,
    },
)
```

### Image configuration

Control the registry and Python version for sandbox images:

```python{hl_lines=["6-10"]}
from flyte.sandbox import ImageConfig

agent = AutoCoderAgent(
    name="my-task",
    model="gpt-4.1",
    image_config=ImageConfig(
        registry="my-registry.io",
        registry_secret="registry-creds",
        python_version=(3, 12),
    ),
)
```

### Skipping tests

Set `skip_tests=True` to skip test generation and execution. The agent still generates code, detects packages, and builds the sandbox image, but does not generate or run tests.

```python{hl_lines=[4]}
agent = AutoCoderAgent(
    name="my-task",
    model="gpt-4.1",
    skip_tests=True,
)
```

> [!NOTE]
> `skip_tests` only applies to LiteLLM mode. In Agent mode, the agent autonomously decides when to test.

### Base packages

Ensure specific packages are always installed in every sandbox:

```python{hl_lines=[4]}
agent = AutoCoderAgent(
    name="my-task",
    model="gpt-4.1",
    base_packages=["numpy", "pandas"],
)
```

## Best practices

- **One agent per task.** Each `generate()` call builds its own sandbox image and manages its own package state. Running multiple agents in the same task can cause resource contention and make failures harder to diagnose.
- **Keep `cache="auto"` (the default).** The cache setting propagates to all internal sandboxes, making retries near-instant. Use `"disable"` during development if you want fresh executions, or `"override"` to force re-execution and update the cached result.
- **Set `max_iterations` conservatively.** Start with 5-10 iterations. If the model cannot produce correct code in that budget, the prompt or constraints likely need refinement.
- **Provide constraints for data-heavy tasks.** Explicit constraints (e.g., `"quantity must be positive"`) produce better schemas and better generated code.
- **Inspect `result.generated_schemas`.** Review the inferred Pandera schemas to verify the model understood your data structure correctly.

## API reference

### `AutoCoderAgent` constructor

| Parameter         | Type              | Default        | Description                                                                            |
| ----------------- | ----------------- | -------------- | -------------------------------------------------------------------------------------- |
| `name`            | `str`             | `"auto-coder"` | Unique name for tracking and image naming                                              |
| `model`           | `str`             | `"gpt-4.1"`    | LiteLLM model identifier                                                               |
| `backend`         | `str`             | `"litellm"`    | Execution backend: `"litellm"` or `"claude"`                                           |
| `system_prompt`   | `str`             | `None`         | Custom system prompt override                                                          |
| `api_key`         | `str`             | `None`         | Name of the environment variable containing the LLM API key (e.g., `"OPENAI_API_KEY"`) |
| `api_base`        | `str`             | `None`         | Custom API base URL                                                                    |
| `litellm_params`  | `dict`            | `None`         | Extra LiteLLM params (temperature, max_tokens, etc.)                                   |
| `base_packages`   | `list[str]`       | `None`         | Always-install pip packages                                                            |
| `resources`       | `flyte.Resources` | `None`         | Resources for sandbox execution (default: 1 CPU, 1Gi)                                  |
| `image_config`    | `ImageConfig`     | `None`         | Registry, secret, and Python version                                                   |
| `max_iterations`  | `int`             | `10`           | Max generate-test-fix iterations (LiteLLM mode)                                        |
| `max_sample_rows` | `int`             | `100`          | Rows to sample from data for LLM context                                               |
| `skip_tests`      | `bool`            | `False`        | Skip test generation and execution (LiteLLM mode)                                      |
| `sandbox_retries` | `int`             | `0`            | Flyte task-level retries for each sandbox execution                                    |
| `timeout`         | `int`             | `None`         | Timeout in seconds for sandboxes                                                       |
| `env_vars`        | `dict[str, str]`  | `None`         | Environment variables for sandboxes                                                    |
| `secrets`         | `list[Secret]`    | `None`         | Flyte secrets for sandboxes                                                            |
| `cache`           | `str`             | `"auto"`       | Cache behavior: `"auto"`, `"override"`, or `"disable"`                                 |
| `agent_max_turns` | `int`             | `50`           | Max turns when `backend="claude"`                                                      |

### `generate()` parameters

| Parameter     | Type                           | Default  | Description                                                                             |
| ------------- | ------------------------------ | -------- | --------------------------------------------------------------------------------------- |
| `prompt`      | `str`                          | required | Natural-language task description                                                       |
| `schema`      | `str`                          | `None`   | Free-form context about data formats or target structures                               |
| `constraints` | `list[str]`                    | `None`   | Natural-language constraints (e.g., `"quantity must be positive"`)                      |
| `samples`     | `dict[str, File \| DataFrame]` | `None`   | Sample data. DataFrames are auto-converted to CSV files.                                |
| `inputs`      | `dict[str, type]`              | `None`   | Non-sample input types (e.g., `{"threshold": float}`)                                   |
| `outputs`     | `dict[str, type]`              | `None`   | Output types. Supported: `str`, `int`, `float`, `bool`, `datetime`, `timedelta`, `File` |

### `CodeGenEvalResult` fields

| Field                      | Type                      | Description                                               |
| -------------------------- | ------------------------- | --------------------------------------------------------- |
| `success`                  | `bool`                    | Whether tests passed                                      |
| `solution`                 | `CodeSolution`            | Generated code (`.code`, `.language`, `.system_packages`) |
| `tests`                    | `str`                     | Generated test code                                       |
| `output`                   | `str`                     | Test output                                               |
| `exit_code`                | `int`                     | Test exit code                                            |
| `error`                    | `str \| None`             | Error message if failed                                   |
| `attempts`                 | `int`                     | Number of iterations used                                 |
| `image`                    | `str`                     | Built sandbox image with all dependencies                 |
| `detected_packages`        | `list[str]`               | Pip packages detected                                     |
| `detected_system_packages` | `list[str]`               | Apt packages detected                                     |
| `generated_schemas`        | `dict[str, str] \| None`  | Pandera schemas as Python code strings                    |
| `data_context`             | `str \| None`             | Extracted data context                                    |
| `original_samples`         | `dict[str, File] \| None` | Sample data as Files (defaults for `run()`/`as_task()`)   |
| `total_input_tokens`       | `int`                     | Total input tokens across all LLM calls                   |
| `total_output_tokens`      | `int`                     | Total output tokens across all LLM calls                  |
| `conversation_history`     | `list[dict]`              | Full LLM conversation history for debugging               |

### `CodeGenEvalResult` methods

| Method                              | Description                                                        |
| ----------------------------------- | ------------------------------------------------------------------ |
| `result.run(**overrides)`           | Execute generated code in a sandbox. Sample data used as defaults. |
| `await result.run.aio(**overrides)` | Async version of `run()`.                                          |
| `result.as_task(name, ...)`         | Create a reusable callable sandbox task from the generated code.   |

Both `run()` and `as_task()` accept optional `name`, `resources`, `retries`, `timeout`, `env_vars`, `secrets`, and `cache` parameters.

=== PAGE: https://www.union.ai/docs/v2/union/integrations/dask ===

# Dask

The Dask plugin lets you run [Dask](https://www.dask.org/) jobs natively on Kubernetes. Flyte provisions a transient Dask cluster for each task execution using the [Dask Kubernetes Operator](https://kubernetes.dask.org/en/latest/operator.html) and tears it down on completion.

## When to use this plugin

- Parallel Python workloads that outgrow a single machine
- Distributed DataFrame operations on large datasets
- Workloads that use Dask's task scheduler for arbitrary computation graphs
- Jobs that need to scale NumPy, pandas, or scikit-learn workflows across multiple nodes

## Installation

```bash
pip install flyteplugins-dask
```

Your task image must also include the Dask distributed scheduler:

```python
import flyte

image = flyte.Image.from_debian_base(name="dask").with_pip_packages("flyteplugins-dask")
```

## Configuration

Create a `Dask` configuration and pass it as `plugin_config` to a `TaskEnvironment`:

```python
import flyte
from flyteplugins.dask import Dask, Scheduler, WorkerGroup

dask_config = Dask(
    scheduler=Scheduler(),
    workers=WorkerGroup(number_of_workers=4),
)

dask_env = flyte.TaskEnvironment(
    name="dask_env",
    plugin_config=dask_config,
    image=image,
)
```

### `Dask` parameters

| Parameter | Type | Description |
|-----------|------|-------------|
| `scheduler` | `Scheduler` | Scheduler pod configuration (defaults to `Scheduler()`) |
| `workers` | `WorkerGroup` | Worker group configuration (defaults to `WorkerGroup()`) |

### `Scheduler` parameters

| Parameter | Type | Description |
|-----------|------|-------------|
| `image` | `str` | Custom scheduler image (must include `dask[distributed]`) |
| `resources` | `Resources` | Resource requests for the scheduler pod |

### `WorkerGroup` parameters

| Parameter | Type | Description |
|-----------|------|-------------|
| `number_of_workers` | `int` | Number of worker pods (default: `1`) |
| `image` | `str` | Custom worker image (must include `dask[distributed]`) |
| `resources` | `Resources` | Resource requests per worker pod |

> [!NOTE]
> The scheduler and all workers should use the same Python environment to avoid serialization issues.

### Accessing the Dask client

Inside a Dask task, create a `distributed.Client()` with no arguments. It automatically connects to the provisioned cluster:

```python
from distributed import Client

@dask_env.task
async def my_dask_task(n: int) -> list:
    client = Client()
    futures = client.map(lambda x: x + 1, range(n))
    return client.gather(futures)
```
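When iterating on the mapped function, it can help to check its logic locally first. The sketch below reproduces the map-then-gather shape with the standard library's `concurrent.futures`; it is not Dask and does not touch the cluster, but like `client.gather` on a list of futures it returns results in input order. `local_map_gather` is a hypothetical helper.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable, Iterable, List


def local_map_gather(fn: Callable[[Any], Any], items: Iterable[Any],
                     max_workers: int = 4) -> List[Any]:
    """Apply fn to each item concurrently and collect results in input order."""
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        futures = [pool.submit(fn, item) for item in items]  # like client.map
        return [future.result() for future in futures]       # like client.gather


print(local_map_gather(lambda x: x + 1, range(3)))  # [1, 2, 3]
```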

## Example

```python
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "flyteplugins-dask",
#    "distributed"
# ]
# main = "hello_dask_nested"
# params = ""
# ///

import asyncio
import typing

from distributed import Client
from flyteplugins.dask import Dask, Scheduler, WorkerGroup

import flyte.remote
import flyte.storage
from flyte import Resources

image = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages("flyteplugins-dask")

dask_config = Dask(
    scheduler=Scheduler(),
    workers=WorkerGroup(number_of_workers=4),
)

task_env = flyte.TaskEnvironment(
    name="hello_dask", resources=Resources(cpu=(1, 2), memory=("400Mi", "1000Mi")), image=image
)
dask_env = flyte.TaskEnvironment(
    name="dask_env",
    plugin_config=dask_config,
    image=image,
    resources=Resources(cpu="1", memory="1Gi"),
    depends_on=[task_env],
)

@task_env.task()
async def hello_dask():
    await asyncio.sleep(5)
    print("Hello from the Dask task!")

@dask_env.task
async def hello_dask_nested(n: int = 3) -> typing.List[int]:
    print("running dask task")
    t = asyncio.create_task(hello_dask())
    client = Client()
    futures = client.map(lambda x: x + 1, range(n))
    res = client.gather(futures)
    await t
    return res

if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(hello_dask_nested)
    print(r.name)
    print(r.url)
    r.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/dask/dask_example.py*

## API reference

See the [Dask API reference](https://www.union.ai/docs/v2/union/api-reference/integrations/dask/_index) for full details.

=== PAGE: https://www.union.ai/docs/v2/union/integrations/databricks ===

# Databricks

The Databricks plugin lets you run PySpark jobs on [Databricks](https://www.databricks.com/) clusters directly from Flyte tasks. You write normal PySpark code in a Flyte task, and the plugin submits it to Databricks via the [Jobs API 2.1](https://docs.databricks.com/api/workspace/jobs/submit). The connector handles job submission, polling, and cancellation.

The plugin supports:

- Running PySpark tasks on new or existing Databricks clusters
- Full Spark configuration (driver/executor memory, cores, instances)
- Databricks cluster auto-scaling
- API token-based authentication

## Installation

```bash
pip install flyteplugins-databricks
```

This also installs `flyteplugins-spark` as a dependency, since the Databricks plugin extends the Spark plugin.

## Quick start

Create a `Databricks` configuration and pass it as `plugin_config` to a `TaskEnvironment`:

```python
from flyteplugins.databricks import Databricks
import flyte

image = (
    flyte.Image.from_base("databricksruntime/standard:16.4-LTS")
    .clone(name="spark", registry="ghcr.io/flyteorg", extendable=True)
    .with_env_vars({"UV_PYTHON": "/databricks/python3/bin/python"})
    .with_pip_packages("flyteplugins-databricks", pre=True)
)

databricks_conf = Databricks(
    spark_conf={
        "spark.driver.memory": "2000M",
        "spark.executor.memory": "1000M",
        "spark.executor.cores": "1",
        "spark.executor.instances": "2",
        "spark.driver.cores": "1",
    },
    executor_path="/databricks/python3/bin/python",
    databricks_conf={
        "run_name": "flyte databricks plugin",
        "new_cluster": {
            "spark_version": "13.3.x-scala2.12",
            "node_type_id": "m6i.large",
            "autoscale": {"min_workers": 1, "max_workers": 2},
        },
        "timeout_seconds": 3600,
        "max_retries": 1,
    },
    databricks_instance="myaccount.cloud.databricks.com",
    databricks_token="DATABRICKS_TOKEN",
)

databricks_env = flyte.TaskEnvironment(
    name="databricks_env",
    resources=flyte.Resources(cpu=(1, 2), memory=("3000Mi", "5000Mi")),
    plugin_config=databricks_conf,
    image=image,
)
```

Then use the environment to decorate your task:

```python
@databricks_env.task
async def hello_databricks() -> float:
    spark = flyte.ctx().data["spark_session"]
    # Use spark as a normal SparkSession
    count = spark.sparkContext.parallelize(range(100)).count()
    return float(count)
```

## Configuration

The `Databricks` config extends the [Spark](../spark/_index) config with Databricks-specific fields.

### Spark fields (inherited)

| Parameter | Type | Description |
|-----------|------|-------------|
| `spark_conf` | `Dict[str, str]` | Spark configuration key-value pairs |
| `hadoop_conf` | `Dict[str, str]` | Hadoop configuration key-value pairs |
| `executor_path` | `str` | Path to the Python binary on the Databricks cluster (e.g., `/databricks/python3/bin/python`) |
| `applications_path` | `str` | Path to the main application file |

### Databricks-specific fields

| Parameter | Type | Description |
|-----------|------|-------------|
| `databricks_conf` | `Dict[str, Union[str, dict]]` | Databricks [run-submit](https://docs.databricks.com/api/workspace/jobs/submit) job configuration. Must contain either `existing_cluster_id` or `new_cluster` |
| `databricks_instance` | `str` | Your workspace domain (e.g., `myaccount.cloud.databricks.com`). Can also be set via the `FLYTE_DATABRICKS_INSTANCE` env var on the connector |
| `databricks_token` | `str` | Name of the Flyte secret containing the Databricks API token |

### `databricks_conf` structure

The `databricks_conf` dict maps to the Databricks run-submit API payload. Key fields:

| Field | Description |
|-------|-------------|
| `new_cluster` | Cluster spec with `spark_version`, `node_type_id`, `autoscale`, etc. |
| `existing_cluster_id` | ID of an existing cluster to use instead of creating a new one |
| `run_name` | Display name in the Databricks UI |
| `timeout_seconds` | Maximum job duration, in seconds |
| `max_retries` | Number of retries before marking the job as failed |

The connector automatically injects the Docker image, Spark configuration, and environment variables from the task container into the cluster spec.
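
For example, a payload that targets an already-running cluster uses `existing_cluster_id` in place of `new_cluster`. The sketch below is plain Python, not plugin code. `check_cluster_spec` is a hypothetical helper that mirrors the "must contain either `existing_cluster_id` or `new_cluster`" rule, and the run name and cluster ID are placeholders. A dict like this would be passed as `databricks_conf` to `Databricks(...)`.

```python
from typing import Any


# Hypothetical helper (not part of flyteplugins): checks that a run-submit
# payload names a cluster to run on, mirroring the rule in the table above.
def check_cluster_spec(conf: dict[str, Any]) -> str:
    """Return "existing" or "new" depending on which cluster field is set."""
    if "existing_cluster_id" in conf:
        return "existing"
    if "new_cluster" in conf:
        return "new"
    raise ValueError("databricks_conf must contain 'existing_cluster_id' or 'new_cluster'")


# Payload that reuses an already-running cluster instead of creating one.
# The cluster ID is a placeholder; use the ID shown in your Databricks workspace.
existing_cluster_conf = {
    "run_name": "flyte-databricks-job",
    "existing_cluster_id": "<your-cluster-id>",
    "timeout_seconds": 3600,
    "max_retries": 1,
}

mode = check_cluster_spec(existing_cluster_conf)
```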

## Authentication

Store your Databricks API token as a Flyte secret. The `databricks_token` parameter specifies the secret name:

```python
databricks_conf = Databricks(
    # ...
    databricks_token="DATABRICKS_TOKEN",
)
```

## Accessing the Spark session

Inside a Databricks task, the `SparkSession` is available through the task context, just like the [Spark plugin](../spark/_index):

```python
@databricks_env.task
async def my_databricks_task() -> float:
    spark = flyte.ctx().data["spark_session"]
    df = spark.read.parquet("s3://my-bucket/data.parquet")
    return float(df.count())
```

## API reference

See the [Databricks API reference](https://www.union.ai/docs/v2/union/api-reference/integrations/databricks/_index) for full details.

=== PAGE: https://www.union.ai/docs/v2/union/integrations/gemini ===

# Gemini

The Gemini plugin lets you build agentic workflows with [Gemini](https://ai.google.dev/) on Flyte. It provides a `function_tool` decorator that wraps Flyte tasks as tools that Gemini can call, and a `run_agent` function that drives the agent conversation loop.

When Gemini calls a tool, the call executes as a Flyte task with full observability, retries, and caching. Gemini's native parallel function calling is supported: multiple tool calls in a single turn are all dispatched and their results bundled into one response.

## Installation

```bash
pip install flyteplugins-gemini
```

Requires `google-genai >= 1.0.0`.

## Quick start

```python
import flyte
from flyteplugins.gemini import function_tool, run_agent

env = flyte.TaskEnvironment(
    name="gemini-agent",
    resources=flyte.Resources(cpu=1, memory="250Mi"),
    image=flyte.Image.from_uv_script(__file__, name="gemini_agent"),
    secrets=flyte.Secret("google_api_key", as_env_var="GOOGLE_API_KEY"),
)

@function_tool
@env.task
async def get_weather(city: str) -> str:
    """Get the current weather for a city."""
    return f"The weather in {city} is sunny, 72F"

@env.task
async def main(prompt: str) -> str:
    tools = [get_weather]
    return await run_agent(prompt=prompt, tools=tools)
```

## API

### `function_tool`

Converts a Flyte task, `@flyte.trace`-decorated function, or plain callable into a tool that Gemini can invoke.

```python
@function_tool
@env.task
async def my_tool(param: str) -> str:
    """Tool description sent to Gemini."""
    ...
```

`function_tool` can also be called with optional overrides:

```python
@function_tool(name="custom_name", description="Custom description")
@env.task
async def my_tool(param: str) -> str:
    ...
```

Parameters:

| Parameter | Type | Description |
|-----------|------|-------------|
| `func` | callable | The function to wrap |
| `name` | `str` | Override the tool name (defaults to the function name) |
| `description` | `str` | Override the tool description (defaults to the docstring) |

> **📝 Note**
>
> The docstring on each `@function_tool` task is sent to Gemini as the tool description. Write clear, concise docstrings.
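
To see why the docstring matters, here is a rough sketch of the kind of function declaration a tool wrapper can derive from a function's signature and docstring. `sketch_tool_declaration` and its type mapping are hypothetical; the plugin's actual schema generation may differ, for example in how it handles optional or nested types.

```python
import inspect
from typing import Any, Callable, Optional

# Hypothetical mapping from Python annotations to JSON-schema type names.
_JSON_TYPES = {str: "string", int: "integer", float: "number", bool: "boolean"}


def sketch_tool_declaration(
    func: Callable[..., Any],
    name: Optional[str] = None,
    description: Optional[str] = None,
) -> dict[str, Any]:
    """Build a JSON-schema-style function declaration from a signature and docstring."""
    properties: dict[str, Any] = {}
    required: list[str] = []
    for param in inspect.signature(func).parameters.values():
        # Unknown annotations fall back to "string" in this simplified sketch.
        properties[param.name] = {"type": _JSON_TYPES.get(param.annotation, "string")}
        if param.default is inspect.Parameter.empty:
            required.append(param.name)
    return {
        # Overrides win; otherwise use the function name and its docstring.
        "name": name or func.__name__,
        "description": description or inspect.getdoc(func) or "",
        "parameters": {"type": "object", "properties": properties, "required": required},
    }


async def get_weather(city: str) -> str:
    """Get the current weather for a city."""
    return f"The weather in {city} is sunny, 72F"


declaration = sketch_tool_declaration(get_weather)
```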

### `Agent`

A dataclass for bundling agent configuration:

```python
from flyteplugins.gemini import Agent

agent = Agent(
    name="my-agent",
    instructions="You are a helpful assistant.",
    model="gemini-2.5-flash",
    tools=[get_weather],
    max_output_tokens=8192,
    max_iterations=10,
)
```

| Field | Type | Default | Description |
|-------|------|---------|-------------|
| `name` | `str` | `"assistant"` | Agent name |
| `instructions` | `str` | `"You are a helpful assistant."` | System prompt |
| `model` | `str` | `"gemini-2.5-flash"` | Gemini model ID |
| `tools` | `list[FunctionTool]` | `[]` | Tools available to the agent |
| `max_output_tokens` | `int` | `8192` | Maximum tokens per response |
| `max_iterations` | `int` | `10` | Maximum tool-call loop iterations |

### `run_agent`

Runs a Gemini conversation loop, dispatching tool calls to Flyte tasks until Gemini returns a final response.

```python
result = await run_agent(
    prompt="What's the weather in Tokyo?",
    tools=[get_weather],
    model="gemini-2.5-flash",
)
```

You can also pass an `Agent` object:

```python
result = await run_agent(prompt="What's the weather?", agent=agent)
```

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `prompt` | `str` | required | User message |
| `tools` | `list[FunctionTool]` | `None` | Tools available to the agent |
| `agent` | `Agent` | `None` | Agent config (overrides individual params) |
| `model` | `str` | `"gemini-2.5-flash"` | Gemini model ID |
| `system` | `str` | `None` | System prompt |
| `max_output_tokens` | `int` | `8192` | Maximum tokens per response |
| `max_iterations` | `int` | `10` | Maximum iterations (prevents infinite loops) |
| `api_key` | `str` | `None` | API key (falls back to `GOOGLE_API_KEY` env var) |
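
The conversation loop can be sketched with the standard library alone. This is a simplified illustration, not the plugin's implementation: `ModelTurn`, `sketch_agent_loop`, and `stub_model` are hypothetical, a stub coroutine stands in for Gemini, and plain async functions stand in for Flyte tasks. It shows two documented behaviors: all tool calls in one turn are dispatched concurrently and their results go back in a single message, and `max_iterations` bounds the loop. What happens when the limit is reached is an assumption here (the sketch raises).

```python
import asyncio
from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable, Optional


@dataclass
class ModelTurn:
    """Hypothetical stand-in for one model response: final text or tool calls."""
    text: Optional[str] = None
    tool_calls: list[tuple[str, dict[str, Any]]] = field(default_factory=list)


async def sketch_agent_loop(
    prompt: str,
    model: Callable[[list[dict[str, Any]]], Awaitable[ModelTurn]],
    tools: dict[str, Callable[..., Awaitable[Any]]],
    max_iterations: int = 10,
) -> str:
    history: list[dict[str, Any]] = [{"role": "user", "content": prompt}]
    for _ in range(max_iterations):
        turn = await model(history)
        if not turn.tool_calls:
            return turn.text or ""  # no tool calls: this is the final answer
        # Dispatch every tool call from this turn concurrently...
        results = await asyncio.gather(*(tools[name](**args) for name, args in turn.tool_calls))
        # ...and bundle all results into a single message for the next turn.
        history.append({
            "role": "tool",
            "content": [{"name": name, "result": res} for (name, _), res in zip(turn.tool_calls, results)],
        })
    # Assumption: this sketch raises once the iteration limit is reached.
    raise RuntimeError(f"no final response after {max_iterations} iterations")


async def get_weather(city: str) -> str:
    return f"The weather in {city} is sunny, 72F"


async def stub_model(history: list[dict[str, Any]]) -> ModelTurn:
    """Hypothetical model: requests two tool calls in parallel, then summarizes."""
    if history[-1]["role"] == "user":
        return ModelTurn(tool_calls=[("get_weather", {"city": "Tokyo"}), ("get_weather", {"city": "Paris"})])
    return ModelTurn(text=" | ".join(r["result"] for r in history[-1]["content"]))


answer = asyncio.run(sketch_agent_loop("Tokyo or Paris?", stub_model, {"get_weather": get_weather}))
```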

## Secrets

Store your Google API key as a Flyte secret and expose it as an environment variable:

```python
secrets=flyte.Secret("google_api_key", as_env_var="GOOGLE_API_KEY")
```

## API reference

See the [Gemini API reference](https://www.union.ai/docs/v2/union/api-reference/integrations/gemini/_index) for full details.

=== PAGE: https://www.union.ai/docs/v2/union/integrations/hydra ===

# Hydra

[Hydra](https://hydra.cc) is a framework for composing and overriding configuration trees from YAML files, dataclasses and the command line. The `flyteplugins-hydra` plugin makes Hydra a first-class submission layer for Flyte, so you can compose a config exactly as you would in any other Hydra app and have each composed run executed as a Flyte task, locally or as a remote execution on a Union.ai cluster.

The plugin offers three complementary entry points that share a single launcher implementation:

| Entry point                                    | Use it when                                                                                                                              |
| ---------------------------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------- |
| `hydra/launcher=flyte` (Hydra Launcher plugin) | You already have a `@hydra.main` script and want standard Hydra CLI ergonomics, including `--multirun` and custom sweepers.              |
| `flyte hydra run` (Flyte CLI extension)        | You want a Flyte-style CLI that imports a task from a Python file and composes a Hydra config without requiring a `@hydra.main` wrapper. |
| `hydra_run` / `hydra_sweep` (Python SDK)       | You want to submit runs directly from Python -- notebooks, tests, examples or another orchestration script.                              |

All three paths converge on the same `FlyteLauncher`.

## Installation

```bash
pip install flyteplugins-hydra
```

The plugin depends on `flyteplugins-omegaconf`, which is installed automatically and provides the `DictConfig`/`ListConfig` type transformers that allow Hydra-composed configs to flow into Flyte tasks. Both packages must be available in the same environment as `flyte`.

If you call `apply_task_env` for child tasks (see **Hydra > Task environment overrides**), include `flyteplugins-hydra` in the task image as well.

## Requirements on tasks

Every task launched through this plugin must accept an OmegaConf `DictConfig` input. Any other parameters are passed through as ordinary task arguments.

```python{hl_lines=[1, 4]}
from omegaconf import DictConfig

@env.task
async def pipeline(cfg: DictConfig, dataset: str) -> float:
    ...
```

The plugin auto-detects the `DictConfig` parameter name. If your parameter is `cfg`, app-level overrides are passed through `--cfg` on the CLI; if it is `config`, they are passed through `--config`; and so on.
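
The detection can be approximated with `inspect` and `typing.get_type_hints`. In this minimal sketch, `find_config_param` is a hypothetical helper that assumes exactly one `DictConfig` parameter, and a local `DictConfig` class stands in for OmegaConf's so the snippet runs on its own. The plugin's own detection logic may differ.

```python
import inspect
from typing import Any, Callable, get_type_hints


class DictConfig:
    """Stand-in for omegaconf.DictConfig so this sketch runs without OmegaConf."""


def find_config_param(func: Callable[..., Any], config_type: type = DictConfig) -> str:
    """Return the name of the one parameter annotated with `config_type` (hypothetical helper)."""
    hints = get_type_hints(func)
    matches = [p for p in inspect.signature(func).parameters if hints.get(p) is config_type]
    if len(matches) != 1:
        raise TypeError(f"{func.__name__} needs exactly one {config_type.__name__} parameter, found {matches}")
    return matches[0]


async def pipeline(cfg: DictConfig, dataset: str) -> float: ...


async def pipeline_with_config(config: DictConfig) -> float: ...


# The CLI flag for app-level overrides is derived from the parameter name.
flags = {f.__name__: f"--{find_config_param(f)}" for f in (pipeline, pipeline_with_config)}
```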

## A walkthrough config

The examples on this page assume a small project layout:

```
project/
├── train.py
└── conf/
    ├── training.yaml
    ├── model/
    │   ├── resnet.yaml
    │   └── vit.yaml
    ├── optimizer/
    │   ├── adam.yaml
    │   └── sgd.yaml
    └── task_env/
        ├── a100.yaml
        └── prebuilt_image.yaml
```

`conf/training.yaml`:

```yaml
defaults:
  - optimizer: adam
  - model: resnet
  - _self_

data:
  path: s3://my-bucket/imagenet
  dataset: imagenet

training:
  epochs: 30
  batch_size: 64
```

`train.py` (abbreviated):

```python
import flyte
from omegaconf import DictConfig
from flyteplugins.hydra import apply_task_env

env = flyte.TaskEnvironment(name="training", image=...)

@env.task
async def preprocess(cfg: DictConfig) -> flyte.io.Dir: ...

@env.task
async def train_model(cfg: DictConfig, data: flyte.io.Dir) -> tuple[flyte.io.Dir, float]: ...

@env.task
async def pipeline(cfg: DictConfig, dataset: str) -> float:
    data = await preprocess(cfg)
    train_task = apply_task_env(train_model, cfg)
    _, val_loss = await train_task(cfg, data)
    return val_loss
```

The same `pipeline` task is the target of every example below.

> **📝 Note**
>
> `config_path` is resolved relative to the current working directory. If you submit runs from a directory other than `project/`, pass an absolute path instead: an absolute `config_path` in Python, or `--config-path /abs/path/to/conf` on the CLI. For structured-config-only setups (no YAML files), omit `config_path` / `--config-path` entirely.

## Execution mode

Remote execution is the default. Every entry point exposes an explicit knob:

| Surface                | Local                       | Remote                                 |
| ---------------------- | --------------------------- | -------------------------------------- |
| `@hydra.main` launcher | `hydra.launcher.mode=local` | `hydra.launcher.mode=remote` (default) |
| `flyte hydra run`      | `--local`                   | `--mode remote` (default)              |
| Python SDK             | `mode="local"`              | `mode="remote"` (default)              |

For the `@hydra.main` launcher, the default applies as soon as `hydra/launcher=flyte` is selected.

Remote runs print the Flyte run URL immediately after submission, before any waiting. By default the plugin then waits for every submitted run to reach a terminal phase, capped at 32 worker threads. To tune or disable waiting:

| Surface                | Tune wait threads                    | Fire and forget             |
| ---------------------- | ------------------------------------ | --------------------------- |
| `@hydra.main` launcher | `hydra.launcher.wait_max_workers=64` | `hydra.launcher.wait=false` |
| `flyte hydra run`      | `--wait-max-workers 64`              | `--no-wait`                 |
| Python SDK             | `wait_max_workers=64`                | `wait=False`                |

For a sweep, the plugin submits every job first and then waits for all runs concurrently, so submission is never blocked waiting for an earlier run to reach a terminal phase.
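
The submit-then-wait pattern can be sketched with the standard library. `submit` and `wait_for` are hypothetical stand-ins for the plugin's remote calls (they do not contact Flyte), and `submit_then_wait` only illustrates the ordering and the bounded worker pool; it is not the plugin's code.

```python
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Optional


def submit(job_id: int) -> str:
    """Hypothetical stand-in for submitting one run; returns a fake run URL."""
    return f"https://example.invalid/runs/{job_id}"


def wait_for(url: str) -> str:
    """Hypothetical stand-in for blocking until a run reaches a terminal phase."""
    time.sleep(0.01)
    return "SUCCEEDED"


def submit_then_wait(
    job_ids: list[int], wait: bool = True, wait_max_workers: int = 32
) -> list[tuple[str, Optional[str]]]:
    # Phase 1: submit every job up front and print each URL immediately.
    urls = [submit(job_id) for job_id in job_ids]
    for url in urls:
        print(url)
    if not wait:
        return [(url, None) for url in urls]  # fire and forget
    # Phase 2: wait on all runs concurrently with a bounded thread pool.
    with ThreadPoolExecutor(max_workers=wait_max_workers) as pool:
        phases = list(pool.map(wait_for, urls))  # map preserves submission order
    return list(zip(urls, phases))


results = submit_then_wait(list(range(6)))
```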

## Hydra launcher (`@hydra.main` scripts)

Use this path when your script already has a `@hydra.main` entry point. Selecting `hydra/launcher=flyte` swaps Hydra's built-in `BasicLauncher` for `FlyteLauncher`.

Single remote run:

```bash
python train.py hydra/launcher=flyte hydra.launcher.mode=remote
```

Single local run:

```bash
python train.py hydra/launcher=flyte hydra.launcher.mode=local
```

Remote grid sweep: every combination of the comma-separated values becomes a separate Flyte execution (3 × 2 = 6 executions in this example):

```bash{hl_lines=[4]}
python train.py --multirun \
  hydra/launcher=flyte hydra.launcher.mode=remote \
  hydra.launcher.wait_max_workers=64 \
  optimizer.lr=0.001,0.01,0.1 training.epochs=10,20
```

Fire-and-forget sweep submission:

```bash{hl_lines=[2]}
python train.py --multirun \
  hydra/launcher=flyte hydra.launcher.wait=false \
  optimizer.lr=0.001,0.01,0.1
```

Custom sweepers such as Optuna work exactly as they do with the `BasicLauncher`. Selecting `hydra/sweeper=...` activates the sweeper, and `FlyteLauncher` runs each trial as a Flyte execution:

```bash{hl_lines=["3-5"]}
python train.py --multirun \
  hydra/launcher=flyte hydra.launcher.mode=remote \
  hydra/sweeper=optuna hydra.sweeper.n_trials=20 \
  hydra.sweeper.n_jobs=4 \
  "optimizer.lr=interval(1e-4,1e-1)"
```

Inside `@hydra.main`, the standard pattern is:

```python{hl_lines=[9]}
import flyte
import hydra
from omegaconf import DictConfig
from flyteplugins.hydra import apply_task_env

@hydra.main(version_base=None, config_path="conf", config_name="training")
def main(cfg: DictConfig):
    flyte.init_from_config()
    entry_task = apply_task_env(pipeline, cfg)
    return flyte.run(entry_task, cfg=cfg, dataset=cfg.data.dataset)

if __name__ == "__main__":
    main()
```

## Python SDK

`hydra_run` composes one config and runs the task once. `hydra_sweep` expands sweep overrides and runs the task once per combination.

### Single run

```python{hl_lines=[1, 3, 7]}
from flyteplugins.hydra import hydra_run

run = hydra_run(
    pipeline,
    config_path="conf",
    config_name="training",
    overrides=["optimizer.lr=0.01"],
    dataset="s3://my-bucket/imagenet",
    mode="remote",
    wait=True,
    wait_max_workers=64,
)
```

For a remote run with `wait=True`, the return value is a wrapper exposing both `run.url` and `run.value` (the resolved task output). The wrapper is `float()`-castable so Hydra sweepers such as Optuna can consume scalar objectives directly. With `wait=False`, the return value is the underlying `flyte.remote.Run`.
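
To make the `float()` behavior concrete, the sketch below shows the general Python pattern: any object that defines `__float__` can be passed to `float()`. `RunResult` is a hypothetical stand-in, not the plugin's wrapper class, and the URL and value are placeholders.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class RunResult:
    """Hypothetical stand-in for the wrapper returned by a waited remote run."""
    url: str
    value: Any  # the resolved task output

    def __float__(self) -> float:
        # float(result) works whenever the task output is numeric, which is
        # what lets an objective-driven sweeper consume the result directly.
        return float(self.value)


result = RunResult(url="https://example.invalid/runs/abc", value=0.137)
objective = float(result)
```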

### Grid sweep

```python{hl_lines=[7]}
from flyteplugins.hydra import hydra_sweep

runs = hydra_sweep(
    pipeline,
    config_path="conf",
    config_name="training",
    overrides=["optimizer.lr=0.001,0.01,0.1", "training.epochs=10,20"],
    dataset="s3://my-bucket/imagenet",
    mode="remote",
)
```

Six executions are submitted (3 × 2). `runs` is a list aligned with the Cartesian-product order Hydra's `BasicSweeper` produces.
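
The arithmetic behind "six executions" is a Cartesian product over the swept values. The stdlib sketch below reproduces that 3 × 2 expansion for simple comma-separated overrides. `expand_grid` is hypothetical, and because Hydra's `BasicSweeper` uses its own override parser, treat the sketch's ordering as illustrative and rely on the returned `runs` list for the actual order.

```python
from itertools import product


def expand_grid(overrides: list[str]) -> list[list[str]]:
    """Expand comma-separated sweep overrides into one override list per job.

    Simplified, hypothetical helper: it ignores quoting, escaping, and sweep
    functions such as interval(...) or choice(...), which Hydra parses itself.
    """
    axes = []
    for override in overrides:
        key, _, values = override.partition("=")
        axes.append([f"{key}={value}" for value in values.split(",")])
    # itertools.product varies the last axis fastest.
    return [list(combo) for combo in product(*axes)]


jobs = expand_grid(["optimizer.lr=0.001,0.01,0.1", "training.epochs=10,20"])
```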

### Custom sweepers

Custom sweeper plugins are activated by passing their selection in `overrides`:

```python{hl_lines=["5-10"]}
runs = hydra_sweep(
    pipeline,
    config_path="conf",
    config_name="training",
    overrides=[
        "hydra/sweeper=optuna",
        "hydra.sweeper.n_trials=20",
        "hydra.sweeper.n_jobs=4",
        "optimizer.lr=interval(1e-4,1e-1)",
    ],
    dataset="s3://my-bucket/imagenet",
    mode="remote",
)
```

Whenever an override starts with `hydra/`, the plugin invokes the full Hydra runtime so plugin discovery (sweepers, launchers, callbacks) can run. Pure value overrides on the `hydra.*` namespace (for example `hydra.run.dir=...`) do not need the full runtime and are applied per-job by the launcher directly.

### Forwarding `flyte.with_runcontext` options

Use `run_options` to pass Flyte runtime options through to every job:

```python{hl_lines=["8-14"]}
runs = hydra_sweep(
    pipeline,
    config_path="conf",
    config_name="training",
    overrides=["optimizer.lr=0.001,0.01,0.1"],
    dataset="s3://my-bucket/imagenet",
    mode="remote",
    run_options={
        "name": "my-training-sweep",
        "service_account": "default",
        "copy_style": "all",
        "raw_data_path": "s3://my-bucket/raw-data",
        "debug": True,
    },
)
```

## Flyte CLI (`flyte hydra run`)

`flyte hydra run` is registered through the `flyte.plugins.cli.commands` entry point. It loads a task from a Python file, composes a Hydra config, and runs the task without requiring the script to have its own `@hydra.main` function. It also inherits the relevant flags from `flyte run` (`--project`, `--domain`, `--image`, `--name`, `--service-account`, `--raw-data-path`, `--copy-style`, `--debug`, `--local`, `--follow`).

### Single run

Remote (default):

```bash
flyte hydra run --config-path conf --config-name training \
  train.py pipeline --dataset s3://my-bucket/imagenet
```

Forced local:

```bash{hl_lines=[1]}
flyte hydra run --local --config-path conf --config-name training \
  train.py pipeline --dataset s3://my-bucket/imagenet
```

### Grid sweep

```bash{hl_lines=[4]}
flyte hydra run --multirun --config-path conf --config-name training \
  --wait-max-workers 64 \
  train.py pipeline --dataset s3://my-bucket/imagenet \
  --cfg "optimizer.lr=0.001,0.01,0.1" --cfg "training.epochs=10,20"
```

### App-level vs Hydra-namespace overrides

The CLI keeps app-level overrides separate from Hydra runtime overrides so they do not collide with ordinary Flyte task arguments.

App-level overrides target the composed config and are passed through the **task's `DictConfig` parameter name**. For `pipeline(cfg: DictConfig, ...)`, use `--cfg`. For `pipeline_with_config(config: DictConfig, ...)`, use `--config`:

```bash{hl_lines=["3-4", 8]}
flyte hydra run --config-path conf --config-name training \
  train.py pipeline \
  --cfg optimizer.lr=0.01 \
  --cfg training.epochs=20

flyte hydra run --config-path conf --config-name training \
  train.py pipeline_with_config \
  --config optimizer.lr=0.01
```

Hydra runtime overrides (anything in the `hydra.*` or `hydra/*` namespace) go through `--hydra-override`:

```bash{hl_lines=[3, 4]}
flyte hydra run --config-path conf --config-name training \
  train.py pipeline \
  --hydra-override hydra.run.dir=./outputs/exp1 \
  --hydra-override hydra/launcher=flyte
```

Custom sweepers combine the two:

```bash{hl_lines=["3-7"]}
flyte hydra run --multirun --config-path conf --config-name training \
  train.py pipeline --dataset s3://my-bucket/imagenet \
  --hydra-override hydra/sweeper=optuna \
  --hydra-override hydra.sweeper.n_trials=20 \
  --hydra-override hydra.sweeper.n_jobs=4 \
  --cfg "optimizer.lr=interval(1e-4,1e-1)" \
  --cfg "training.epochs=choice(10,20,50)"
```

### `--follow` and `--no-wait`

`--follow` streams logs from the launched run after submission; it implies waiting and cannot be combined with `--no-wait`. `--no-wait` returns immediately after submission and skips log streaming.

### Shell completion

Install Click's completion hook for the `flyte` executable. For zsh:

```zsh
eval "$(_FLYTE_COMPLETE=zsh_source flyte)"
```

For bash:

```bash
eval "$(_FLYTE_COMPLETE=bash_source flyte)"
```

Once installed, `flyte hydra run` adds Hydra-aware completion after `SCRIPT TASK_NAME`. The command imports the script, inspects the task signature, and suggests:

- The app override flag matching the task's `DictConfig` parameter (`--cfg`, `--config`, ...).
- Override values for that flag and `--hydra-override` via Hydra's own completion engine, including config keys, config-group selections and sweep functions.

```bash{hl_lines=["2-3", "6-7"]}
flyte hydra run --config-path conf --config-name training \
  train.py pipeline --cfg optimizer.<TAB>
# suggests optimizer.lr=, optimizer.weight_decay=, ...

flyte hydra run --config-path conf --config-name training \
  train.py pipeline --hydra-override hydra/launcher=<TAB>
# suggests hydra launcher choices
```

Because completion has to import the target script, keep task definitions and `ConfigStore` registration import-safe, and avoid expensive top-level work in scripts you reach via `flyte hydra run`.

![Auto Completion](https://raw.githubusercontent.com/unionai/unionai-docs-static/main/images/integrations/hydra/auto_complete.gif)

## Override grammar

The override grammar is identical to standard Hydra; what differs is only how you pass the strings (positional in `python train.py ...`, list entries in `overrides=[...]`, repeated `--cfg`/`--hydra-override` on the Flyte CLI).

| Form                               | Meaning                                                                                  |
| ---------------------------------- | ---------------------------------------------------------------------------------------- |
| `optimizer.lr=0.01`                | Set an existing key.                                                                     |
| `optimizer=sgd`                    | Select a config group (replaces the `optimizer` subtree with `conf/optimizer/sgd.yaml`). |
| `+task_env=a100`                   | Append a config group whose key is not currently in the config.                          |
| `+training.grad_clip=1.0`          | Append a key that does not exist.                                                        |
| `++optimizer.lr=0.05`              | Force-set a key, creating it if missing and overriding strict-schema errors.             |
| `~training.warmup_steps`           | Delete a key from the composed config.                                                   |
| `optimizer.lr=0.001,0.01,0.1`      | Sweep value (with `--multirun`); expanded into one job per element.                      |
| `optimizer.lr=interval(1e-4,1e-1)` | Continuous sweep range; consumed by samplers like Optuna.                                |
| `optimizer=choice(adam,sgd)`       | Categorical sweep; consumed by samplers.                                                 |
| `hydra.run.dir=./outputs/exp1`     | Hydra-namespace value override (single run output dir).                                  |
| `hydra.sweep.dir=./outputs/sweep1` | Hydra-namespace sweep output dir.                                                        |
| `hydra/sweeper=optuna`             | Hydra-namespace config group selection (activates the Optuna sweeper plugin).            |

## Sweeps

### Grid sweeps (BasicSweeper)

Comma-separated overrides expand into a Cartesian product. The plugin uses Hydra's `BasicSweeper` to expand them, then submits one Flyte execution per combination.

```python{hl_lines=[1, 3, 6]}
from flyteplugins.hydra import hydra_sweep

runs = hydra_sweep(
    pipeline,
    config_path="conf", config_name="training",
    overrides=["model=resnet,vit", "optimizer.lr=0.001,0.01,0.1"],
    dataset="s3://my-bucket/imagenet",
    mode="remote",
)  # 6 executions
```

```bash{hl_lines=[3]}
flyte hydra run --multirun --config-path conf --config-name training \
  train.py pipeline --dataset s3://my-bucket/imagenet \
  --cfg "model=resnet,vit" --cfg "optimizer.lr=0.001,0.01,0.1"
```

Hardware presets can sweep alongside hyperparameters. This example assumes an `a10g.yaml` preset next to `a100.yaml` in `conf/task_env/`:

```bash{hl_lines=[3]}
flyte hydra run --multirun --config-path conf --config-name training \
  train.py pipeline --dataset s3://my-bucket/imagenet \
  --cfg "+task_env=a10g,a100" --cfg "optimizer.lr=0.001,0.01,0.1"
```

### Bayesian / TPE sweeps (Optuna)

Install the sweeper, then activate it via `hydra/sweeper=optuna`. Continuous parameters use `interval(...)`; categorical parameters use `choice(...)`.

```bash
pip install hydra-optuna-sweeper
```

```bash{hl_lines=["3-8"]}
flyte hydra run --multirun --config-path conf --config-name training \
  train.py pipeline --dataset s3://my-bucket/imagenet \
  --hydra-override "hydra/sweeper=optuna" \
  --hydra-override "hydra.sweeper.n_trials=30" \
  --hydra-override "hydra.sweeper.n_jobs=5" \
  --cfg "optimizer.lr=interval(1e-4,1e-1)" \
  --cfg "optimizer.weight_decay=interval(1e-6,1e-2)" \
  --cfg "model=choice(resnet,vit)"
```

When `wait=True`, each remote run's wrapped result exposes the task output as a float (via `__float__`), so Optuna can use it directly as the trial objective. With `wait=False`, the sweeper sees the run URL but cannot read objective values; use this only for fire-and-forget submission.

Other sweepers that respect Hydra's plugin protocol are activated the same way: install the package, select `hydra/sweeper=<name>`, and set the sweeper's parameters under `hydra.sweeper.*`.

### Sweep output directories

Hydra-namespace overrides redirect where Hydra writes per-job logs and config snapshots:

```bash{hl_lines=[3, 4]}
flyte hydra run --multirun --config-path conf --config-name training \
  train.py pipeline --dataset s3://my-bucket/imagenet \
  --hydra-override "hydra.sweep.dir=./outputs/sweep1" \
  --hydra-override "hydra.sweep.subdir=\${hydra.job.num}" \
  --cfg "optimizer.lr=0.001,0.01,0.1"
```

## Task environment overrides

Hydra composes plain YAML values, while Flyte task settings such as resources and container images are richer, typed objects. By default, the plugin reserves a config key named `task_env` that maps task names to `task.override` keyword arguments.

```yaml
task_env:
  pipeline:
    resources:
      cpu: "2"
      memory: 8Gi
  train_model:
    resources:
      cpu: "16"
      memory: 64Gi
      gpu: "A100:1"
```

When the plugin launches a task, it looks up `task_env[<entry-task-name>]` (`pipeline` in this example) and applies the values via `task.override(...)`. Resource mappings are converted into `flyte.Resources(**values)` automatically.
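
The lookup can be sketched on a plain dict. For `pipeline` in the YAML above, the applied override is roughly what you would get from calling `pipeline.override(resources=flyte.Resources(cpu="2", memory="8Gi"))` by hand. `resolve_task_env` is a hypothetical helper, and returning no overrides for a task without an entry is an assumption of this sketch.

```python
from typing import Any


def resolve_task_env(cfg: dict[str, Any], task_name: str, task_env_key: str = "task_env") -> dict[str, Any]:
    """Return the override kwargs configured for `task_name` (hypothetical helper).

    Assumption: a task with no entry gets no overrides. In the plugin, the
    `resources` mapping is then converted to flyte.Resources(**values).
    """
    entries = cfg.get(task_env_key) or {}
    return dict(entries.get(task_name) or {})


# The composed config from the YAML above, as a plain dict.
cfg = {
    "task_env": {
        "pipeline": {"resources": {"cpu": "2", "memory": "8Gi"}},
        "train_model": {"resources": {"cpu": "16", "memory": "64Gi", "gpu": "A100:1"}},
    }
}

pipeline_overrides = resolve_task_env(cfg, "pipeline")
```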

### Prebuilt images

To run a task in a prebuilt container image, set `image` (and optionally `primary_container_name`):

```yaml{hl_lines=[3]}
task_env:
  pipeline:
    image: ghcr.io/acme/flyte-training:latest
    primary_container_name: main
    resources:
      cpu: "4"
      memory: 16Gi
```

`task.override` does not accept `image` directly, because the task image is part of the task definition. Instead, the plugin lowers the override to a `flyte.PodTemplate` whose primary container uses the requested image:

- If the task has no inline pod template, a new one is created.
- If the task already has an inline `flyte.PodTemplate`, the plugin deep-copies it and sets only the image on the primary container.
- If the task references a pod template by name (a string), the plugin raises an error. You must patch a string-named template by editing it in cluster config rather than at submission time.
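
The three rules above can be summarized in a short sketch. It uses hypothetical `Container` and `PodTemplate` dataclasses as stand-ins for Flyte's pod template types and a hypothetical `lower_image_override` function. It illustrates the decision logic only and is not the plugin's code.

```python
import copy
from dataclasses import dataclass, field
from typing import Optional, Union


@dataclass
class Container:
    """Hypothetical stand-in for one container in a pod spec."""
    name: str
    image: Optional[str] = None


@dataclass
class PodTemplate:
    """Hypothetical stand-in for flyte.PodTemplate."""
    primary_container_name: str = "primary"
    containers: list[Container] = field(default_factory=list)


def lower_image_override(
    existing: Union[None, str, PodTemplate], image: str, primary_container_name: str = "primary"
) -> PodTemplate:
    if isinstance(existing, str):
        # Named templates live in cluster config and cannot be patched at submission time.
        raise ValueError(f"cannot set an image on a pod template referenced by name: {existing!r}")
    if existing is None:
        # No inline template: create one whose primary container uses the image.
        return PodTemplate(primary_container_name, [Container(primary_container_name, image)])
    # Inline template: deep-copy it and change only the primary container's image.
    patched = copy.deepcopy(existing)
    for container in patched.containers:
        if container.name == patched.primary_container_name:
            container.image = image
            return patched
    # Assumption: add the primary container if the template does not define it yet.
    patched.containers.append(Container(patched.primary_container_name, image))
    return patched


inline = PodTemplate("main", [Container("main", "old:1"), Container("sidecar", "sidecar:1")])
patched = lower_image_override(inline, "ghcr.io/acme/flyte-training:latest")
```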

### Applying overrides to child tasks

The launcher only controls the entry task it submits. Child tasks called from within the entry task are not patched automatically. Use `apply_task_env` to apply the same `resources`/`image` handling to a child task before invoking it:

```python{hl_lines=[1, 6]}
from flyteplugins.hydra import apply_task_env

@env.task
async def pipeline(cfg: DictConfig, dataset: str) -> float:
    data = await preprocess(cfg)
    train_task = apply_task_env(train_model, cfg)
    _, val_loss = await train_task(cfg, data)
    return val_loss
```

This keeps the override knobs in YAML/CLI surfaces while leaving each task in control of which children it patches.

### Renaming the task-env key

If your config uses a different name for the task-env subtree, pass it explicitly:

```python
hydra_run(..., task_env_key="task_environment")
```

```bash
flyte hydra run --task-env-key task_environment ...
```

### What `task_env` should not model

The YAML schema intentionally omits the full Kubernetes `V1PodSpec`. Keep advanced pod configuration (volumes, init containers, node selectors, etc.) in Python task/environment code where you have a real type. Use Hydra `task_env` presets for the common knobs only: image, primary container name and resources.

## Structured configs (without YAML)

Structured configs work with this plugin as long as they are registered before the launcher composes the config. `flyte hydra run` imports the script first, so top-level `ConfigStore.instance().store(...)` calls run before composition.

```python{hl_lines=[14]}
from dataclasses import dataclass, field
from hydra.core.config_store import ConfigStore
from omegaconf import DictConfig

@dataclass
class TrainingConf:
    epochs: int = 30
    batch_size: int = 64

@dataclass
class RootConf:
    training: TrainingConf = field(default_factory=TrainingConf)

ConfigStore.instance().store(name="structured_training", node=RootConf)
```

Run a fully structured config without YAML:

```bash{hl_lines=[1]}
flyte hydra run --config-name structured_training \
  train.py pipeline --dataset s3://my-bucket/imagenet
```

The same config also works through `@hydra.main`:

```bash
python train.py --config-name structured_training
```

If the structured config still references YAML config groups, keep `--config-path conf`. If everything is registered in `ConfigStore`, omit `--config-path`.

> **⚠️ Warning**
>
> Do not register structured configs only inside `if __name__ == "__main__":` or inside the `@hydra.main` function body. `flyte hydra run` and shell completion inspect the script at import time, before either of those blocks runs, and registrations placed there will not be visible.

Structured configs sweep just like YAML configs:

```python{hl_lines=[4, 5]}
runs = hydra_sweep(
    pipeline,
    config_path=None,
    config_name="structured_training",
    overrides=["training.epochs=10,20", "training.batch_size=32,64"],
    dataset="s3://my-bucket/imagenet",
    mode="remote",
)
```

=== PAGE: https://www.union.ai/docs/v2/union/integrations/mlflow ===

# MLflow

The MLflow plugin integrates [MLflow](https://mlflow.org/) experiment tracking with Flyte. It provides a `@mlflow_run` decorator that automatically manages MLflow runs within Flyte tasks, with support for autologging, parent-child run sharing, distributed training, and auto-generated UI links.

The decorator works with both sync and async tasks.

## Installation

```bash
pip install flyteplugins-mlflow
```

Requires `mlflow` and `flyte`.

## Quick start

```python{hl_lines=[3, 9, "13-16", 22]}
import flyte
import mlflow
from flyteplugins.mlflow import mlflow_run, get_mlflow_run

env = flyte.TaskEnvironment(
    name="mlflow-tracking",
    resources=flyte.Resources(cpu=1, memory="500Mi"),
    image=flyte.Image.from_debian_base(name="mlflow_example").with_pip_packages(
        "flyteplugins-mlflow"
    ),
)

@mlflow_run(
    tracking_uri="http://localhost:5000",
    experiment_name="my-experiment",
)
@env.task
async def train_model(learning_rate: float) -> str:
    mlflow.log_param("lr", learning_rate)
    mlflow.log_metric("loss", 0.42)

    run = get_mlflow_run()
    return run.info.run_id
```

![Link](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/integrations/mlflow/link.png)

![Mlflow UI](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/integrations/mlflow/mlflow_dashboard.png)

> **📝 Note**
>
> `@mlflow_run` must be the outermost decorator, placed above `@env.task`:
>
> ```python{hl_lines=["1-2"]}
> @mlflow_run          # outermost
> @env.task            # innermost
> async def my_task(): ...
> ```

## Autologging

Enable MLflow's autologging to automatically capture parameters, metrics, and models without manual `mlflow.log_*` calls.

### Generic autologging

```python{hl_lines=[1]}
@mlflow_run(autolog=True)
@env.task
async def train():
    from sklearn.linear_model import LogisticRegression

    model = LogisticRegression()
    model.fit(X, y)  # Parameters, metrics, and model are logged automatically
```

### Framework-specific autologging

Pass `framework` to use a framework-specific autolog implementation:

```python{hl_lines=[3]}
@mlflow_run(
    autolog=True,
    framework="sklearn",
    log_models=True,
    log_datasets=False,
)
@env.task
async def train_sklearn():
    from sklearn.ensemble import RandomForestClassifier

    model = RandomForestClassifier(n_estimators=100)
    model.fit(X_train, y_train)
```

Supported frameworks include any framework with an `mlflow.{framework}.autolog()` function. See the [MLflow autologging documentation](https://mlflow.org/docs/latest/ml/tracking/autolog/#supported-libraries) for the full list of supported libraries.

You can pass additional autolog parameters via `autolog_kwargs`:

```python{hl_lines=[4]}
@mlflow_run(
    autolog=True,
    framework="pytorch",
    autolog_kwargs={"log_every_n_epoch": 5},
)
@env.task
async def train_pytorch():
    ...
```

![Autolog](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/integrations/mlflow/autolog.png)

## Run modes

The `run_mode` parameter controls how MLflow runs are created and shared across tasks:

| Mode               | Behavior                                                              |
| ------------------ | --------------------------------------------------------------------- |
| `"auto"` (default) | Reuse the parent's run if one exists, otherwise create a new run      |
| `"new"`            | Always create a new independent run                                   |
| `"nested"`         | Create a new run nested under the parent via `mlflow.parentRunId` tag |
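The table's decision rules can be sketched as a small pure-Python function. This is a hypothetical illustration, not the plugin's code. `plan_run` and its return shape are invented here. Treating `"nested"` without a parent as a top-level run is an assumption. The sketch shows why nesting works across containers: the child only needs the parent's run ID, which it records in the `mlflow.parentRunId` tag.

```python
from typing import Optional


def plan_run(run_mode: str, parent_run_id: Optional[str]) -> dict:
    """Hypothetical helper: decide how a task obtains its MLflow run."""
    if run_mode not in ("auto", "new", "nested"):
        raise ValueError(f"unknown run_mode: {run_mode!r}")
    if run_mode == "auto" and parent_run_id is not None:
        # Reuse the parent's run instead of starting a new one.
        return {"action": "reuse", "run_id": parent_run_id, "tags": {}}
    if run_mode == "nested" and parent_run_id is not None:
        # A new run that the MLflow UI groups under the parent via this tag.
        return {"action": "create", "run_id": None,
                "tags": {"mlflow.parentRunId": parent_run_id}}
    # "new", "auto" without a parent, or (assumption) "nested" without a parent.
    return {"action": "create", "run_id": None, "tags": {}}
```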

### Sharing a run across tasks

With `run_mode="auto"` (the default), child tasks reuse the parent's MLflow run:

```python{hl_lines=[1, 5, 7]}
@mlflow_run
@env.task
async def parent_task():
    mlflow.log_param("stage", "parent")
    await child_task()  # Shares the same MLflow run

@mlflow_run
@env.task
async def child_task():
    mlflow.log_metric("child_metric", 1.0)  # Logged to the parent's run
```

### Creating independent runs

Use `run_mode="new"` when a task should always create its own top-level MLflow run, completely independent of any parent:

```python{hl_lines=[1]}
@mlflow_run(run_mode="new")
@env.task
async def standalone_experiment():
    mlflow.log_param("experiment_type", "baseline")
    mlflow.log_metric("accuracy", 0.95)
```

### Nested runs

Use `run_mode="nested"` to create a child run that appears under the parent in the MLflow UI. This works across processes and containers via the `mlflow.parentRunId` tag.

![Nested runs](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/integrations/mlflow/mlflow_hpo.png)

This is the recommended pattern for hyperparameter optimization, where each trial should be tracked as a child of the parent study run:

```python{hl_lines=[1, 3, 4, "26-28"]}
from flyteplugins.mlflow import Mlflow

@mlflow_run(run_mode="nested")
@env.task(links=[Mlflow()])
async def run_trial(trial_number: int, n_estimators: int, max_depth: int) -> float:
    """Each trial creates a nested MLflow run under the parent."""
    mlflow.log_params({"n_estimators": n_estimators, "max_depth": max_depth})
    mlflow.log_param("trial_number", trial_number)

    model = RandomForestRegressor(n_estimators=n_estimators, max_depth=max_depth)
    model.fit(X_train, y_train)

    rmse = float(np.sqrt(mean_squared_error(y_val, model.predict(X_val))))
    mlflow.log_metric("rmse", rmse)
    return rmse

@mlflow_run
@env.task
async def hpo_search(n_trials: int = 30) -> str:
    """Parent run tracks the overall study."""
    run = get_mlflow_run()
    mlflow.log_param("n_trials", n_trials)

    # Run trials in parallel; each gets a nested MLflow run.
    # trial_params (defined elsewhere) is a list of dicts with n_estimators/max_depth keys.
    rmses = await asyncio.gather(
        *(run_trial(trial_number=i, **params) for i, params in enumerate(trial_params))
    )

    mlflow.log_metric("best_rmse", min(rmses))
    return run.info.run_id
```

![HPO](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/integrations/mlflow/hpo.png)

## Workflow-level configuration

Use `mlflow_config()` with `flyte.with_runcontext()` to set MLflow configuration for an entire workflow. All `@mlflow_run`-decorated tasks in the workflow inherit these settings:

```python{hl_lines=[1, "4-8"]}
from flyteplugins.mlflow import mlflow_config

r = flyte.with_runcontext(
    custom_context=mlflow_config(
        tracking_uri="http://localhost:5000",
        experiment_id="846992856162999",
        tags={"team": "ml"},
    )
).run(train_model, learning_rate=0.001)
```

This eliminates the need to repeat `tracking_uri` and experiment settings on every `@mlflow_run` decorator.

### Per-task overrides

Use `mlflow_config()` as a context manager inside a task to override configuration for specific child tasks:

```python{hl_lines=[6]}
@mlflow_run
@env.task
async def parent_task():
    await shared_child()  # Inherits parent config

    with mlflow_config(run_mode="new", tags={"role": "independent"}):
        await independent_child()  # Gets its own run
```

### Configuration priority

Settings are resolved in priority order:

1. Explicit `@mlflow_run` decorator arguments
2. `mlflow_config()` context configuration
3. Environment variables (for `tracking_uri`)
4. MLflow defaults
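The priority order amounts to "the first value that is set wins". The following is a minimal sketch. It assumes `None` means "not set". It also assumes the environment fallback is MLflow's standard `MLFLOW_TRACKING_URI` variable; the plugin's exact variable handling is not documented here. `resolve_setting` is a hypothetical helper.

```python
import os
from typing import Any, Optional


def resolve_setting(
    name: str,
    decorator_args: dict,
    context_config: dict,
    defaults: dict,
    env_var: Optional[str] = None,
) -> Any:
    """Return the first value that is set, following the documented priority order."""
    candidates = [
        decorator_args.get(name),  # 1. explicit @mlflow_run arguments
        context_config.get(name),  # 2. mlflow_config() context
    ]
    if env_var is not None:
        candidates.append(os.environ.get(env_var))  # 3. env vars (tracking_uri only)
    candidates.append(defaults.get(name))  # 4. MLflow defaults
    return next((value for value in candidates if value is not None), None)
```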

## Distributed training

In distributed training, only rank 0 logs to MLflow by default. The plugin detects rank automatically from the `RANK` environment variable:

```python{hl_lines=[1, "4-6"]}
@mlflow_run
@env.task
async def distributed_train():
    # Only rank 0 creates an MLflow run and logs metrics.
    # Other ranks execute the task function directly without
    # creating an MLflow run or incurring any MLflow overhead.
    ...
```

On non-rank-0 workers, no MLflow run is created and `get_mlflow_run()` returns `None`. The task function still executes normally — only the MLflow instrumentation is skipped.

![Distributed training](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/integrations/mlflow/distributed_training.png)

You can also set rank explicitly:

```python{hl_lines=[1]}
@mlflow_run(rank=0)
@env.task
async def train():
    ...
```
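The rank check reduces to a few lines. The sketch below uses a hypothetical `should_log` helper. It gives an explicit `rank` argument precedence over the `RANK` environment variable. It also treats a missing `RANK` as a single-process run at rank 0. Both behaviors are assumptions about the plugin, not documented guarantees.

```python
import os
from typing import Optional


def should_log(rank: Optional[int] = None) -> bool:
    """Return True only on the process that should own the MLflow run."""
    if rank is None:
        # Launchers such as torchrun export RANK; assume rank 0 when it is absent.
        rank = int(os.environ.get("RANK", "0"))
    return rank == 0
```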

## MLflow UI links

The `Mlflow` link class displays links to the MLflow UI in the Flyte UI.

Since the MLflow run is created inside the task at execution time, the run URL cannot be determined before the task starts. Links are only shown when a run URL is already available from context, either because a parent task created the run, or because an explicit URL is provided.

The recommended pattern is for the parent task to create the MLflow run and for child tasks that inherit the run (via `run_mode="auto"`) to display a link to it. For nested runs (`run_mode="nested"`), children display a link to the parent run.

### Setup

Set `link_host` via `mlflow_config()` and attach `Mlflow()` links to child tasks:

```python{hl_lines=[4, 17]}
from flyteplugins.mlflow import Mlflow, mlflow_config

@mlflow_run
@env.task(links=[Mlflow()])
async def child_task():
    ...  # Link points to the parent's MLflow run

@mlflow_run
@env.task
async def parent_task():
    await child_task()

if __name__ == "__main__":
    r = flyte.with_runcontext(
        custom_context=mlflow_config(
            tracking_uri="http://localhost:5000",
            link_host="http://localhost:5000",
        )
    ).run(parent_task)
```

> [!NOTE]
> `Mlflow()` is instantiated without a `link` argument because the URL is auto-generated at runtime. When the parent task creates an MLflow run, the plugin builds the URL from `link_host` and the run's experiment/run IDs, then propagates it to child tasks via the Flyte context. Passing an explicit `link` would bypass this auto-generation.

### Custom URL templates

The default link format is:

```
{host}/#/experiments/{experiment_id}/runs/{run_id}
```

For platforms like Databricks that use a different URL structure, provide a custom template:

```python{hl_lines=[3]}
mlflow_config(
    link_host="https://dbc-xxx.cloud.databricks.com",
    link_template="{host}/ml/experiments/{experiment_id}/runs/{run_id}",
)
```
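Link generation is plain placeholder substitution into the template. The sketch below uses `str.format` with the three documented placeholders. `build_run_link` is a hypothetical name. Stripping a trailing slash from the host is this sketch's own choice.

```python
DEFAULT_TEMPLATE = "{host}/#/experiments/{experiment_id}/runs/{run_id}"


def build_run_link(
    host: str, experiment_id: str, run_id: str, template: str = DEFAULT_TEMPLATE
) -> str:
    """Fill the documented placeholders; trailing slashes on host are dropped."""
    return template.format(
        host=host.rstrip("/"), experiment_id=experiment_id, run_id=run_id
    )


print(build_run_link(
    "https://dbc-xxx.cloud.databricks.com",
    "42",
    "abc123",
    template="{host}/ml/experiments/{experiment_id}/runs/{run_id}",
))
# https://dbc-xxx.cloud.databricks.com/ml/experiments/42/runs/abc123
```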

### Explicit links

If you know the run URL ahead of time, you can set it directly:

```python{hl_lines=[1]}
@env.task(links=[Mlflow(link="https://mlflow.example.com/#/experiments/1/runs/abc123")])
async def my_task():
    ...
```

### Link behavior by run mode

| Run mode   | Link behavior                                                                                  |
| ---------- | ---------------------------------------------------------------------------------------------- |
| `"auto"`   | Parent link propagates to child tasks sharing the run                                          |
| `"new"`    | Parent link is cleared; no link is shown until the task's own run is available to its children |
| `"nested"` | Parent link is kept and renamed to "MLflow (parent)"                                           |
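Read as code, the table maps each run mode to the links a child task shows. The following is a hypothetical sketch; `links_for_child` is not a plugin API. It returns `{display name: URL}`.

```python
from typing import Optional


def links_for_child(run_mode: str, parent_link: Optional[str]) -> dict:
    """Return the {display name: URL} links a child task would show."""
    if parent_link is None or run_mode == "new":
        return {}  # "new": the parent's link is cleared
    if run_mode == "nested":
        return {"MLflow (parent)": parent_link}  # kept, but renamed
    return {"MLflow": parent_link}  # "auto": child shares the parent's run
```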

## Automatic Flyte tags

When running inside Flyte, the plugin automatically tags MLflow runs with execution metadata:

| Tag                 | Description      |
| ------------------- | ---------------- |
| `flyte.action_name` | Task action name |
| `flyte.run_name`    | Flyte run name   |
| `flyte.project`     | Flyte project    |
| `flyte.domain`      | Flyte domain     |

These tags are merged with any user-provided tags.

## API reference

### `mlflow_run` and `mlflow_config`

`mlflow_run` is a decorator that manages MLflow runs for Flyte tasks. `mlflow_config` creates workflow-level configuration or per-task overrides. Both accept the same core parameters:

| Parameter         | Type             | Default  | Description                                                                   |
| ----------------- | ---------------- | -------- | ----------------------------------------------------------------------------- |
| `run_mode`        | `str`            | `"auto"` | `"auto"`, `"new"`, or `"nested"`                                              |
| `tracking_uri`    | `str`            | `None`   | MLflow tracking server URL                                                    |
| `experiment_name` | `str`            | `None`   | MLflow experiment name (raises `ValueError` if combined with `experiment_id`) |
| `experiment_id`   | `str`            | `None`   | MLflow experiment ID (raises `ValueError` if combined with `experiment_name`) |
| `run_name`        | `str`            | `None`   | Human-readable run name (raises `ValueError` if combined with `run_id`)       |
| `run_id`          | `str`            | `None`   | Explicit MLflow run ID (raises `ValueError` if combined with `run_name`)      |
| `tags`            | `dict[str, str]` | `None`   | Tags for the run                                                              |
| `autolog`         | `bool`           | `False`  | Enable MLflow autologging                                                     |
| `framework`       | `str`            | `None`   | Framework for autolog (e.g. `"sklearn"`, `"pytorch"`)                         |
| `log_models`      | `bool`           | `None`   | Log models automatically (requires `autolog`)                                 |
| `log_datasets`    | `bool`           | `None`   | Log datasets automatically (requires `autolog`)                               |
| `autolog_kwargs`  | `dict`           | `None`   | Extra parameters for `mlflow.autolog()`, or for `mlflow.{framework}.autolog()` when `framework` is set |

Additional keyword arguments are passed to `mlflow.start_run()`.

`mlflow_run` also accepts:

| Parameter | Type  | Default | Description                                              |
| --------- | ----- | ------- | -------------------------------------------------------- |
| `rank`    | `int` | `None`  | Process rank for distributed training (only rank 0 logs) |

`mlflow_config` also accepts:

| Parameter       | Type  | Default | Description                                                                 |
| --------------- | ----- | ------- | --------------------------------------------------------------------------- |
| `link_host`     | `str` | `None`  | MLflow UI host for auto-generating links                                    |
| `link_template` | `str` | `None`  | Custom URL template (placeholders: `{host}`, `{experiment_id}`, `{run_id}`) |

### `get_mlflow_run`

Returns the current `mlflow.ActiveRun` if within a `@mlflow_run`-decorated task. Returns `None` otherwise.

```python
from flyteplugins.mlflow import get_mlflow_run

run = get_mlflow_run()
if run:
    print(run.info.run_id)
```

### `get_mlflow_context`

Returns the current `mlflow_config` settings from the Flyte context, or `None` if no MLflow configuration is set. Useful for inspecting the inherited configuration inside a task:

```python
from flyteplugins.mlflow import get_mlflow_context

@mlflow_run
@env.task
async def my_task():
    config = get_mlflow_context()
    if config:
        print(config.tracking_uri, config.experiment_id)
```

### `Mlflow`

Link class for displaying MLflow UI links in the Flyte console.

| Field  | Type  | Default    | Description                             |
| ------ | ----- | ---------- | --------------------------------------- |
| `name` | `str` | `"MLflow"` | Display name for the link               |
| `link` | `str` | `""`       | Explicit URL (bypasses auto-generation) |

=== PAGE: https://www.union.ai/docs/v2/union/integrations/omegaconf ===

# OmegaConf

[OmegaConf](https://omegaconf.readthedocs.io/) is a hierarchical configuration system used by many ML frameworks (and the foundation of [Hydra](../hydra/_index)). The `flyteplugins-omegaconf` plugin makes OmegaConf's `DictConfig` and `ListConfig` first-class types in Flyte tasks. You can pass entire configs between tasks without flattening them into individual scalar arguments, whether the configs are built from plain dicts, YAML files or dataclass-backed structured configs.

The plugin enables:

- `DictConfig` and `ListConfig` as native task input and output types
- Round-tripping of structured configs (dataclass schemas) across task boundaries
- Preservation of OmegaConf-specific values: `MISSING` sentinels, `Enum`s, `pathlib.Path`s, `tuple`s, and `bytes`
- Resolved variable interpolations on the wire
- A YAML-rendered Flyte report tab for human-readable config inspection

## Installation

```bash
pip install flyteplugins-omegaconf
```

Installing the package automatically registers `DictConfig` and `ListConfig` with Flyte's `TypeEngine`. No manual setup is required.

If you are using the [Hydra plugin](../hydra/_index), `flyteplugins-omegaconf` is installed as a transitive dependency.

## Quick start

```python{hl_lines=[2, 7, "12-15"]}
import flyte
from omegaconf import DictConfig, OmegaConf

env = flyte.TaskEnvironment(name="training", image=...)

@env.task
async def train(cfg: DictConfig) -> float:
    return run_experiment(cfg.optimizer.lr, cfg.training.epochs)

@env.task
async def pipeline() -> float:
    cfg = OmegaConf.create(
        {"optimizer": {"lr": 0.001}, "training": {"epochs": 10}}
    )
    return await train(cfg)
```

The config is serialized when `train` is invoked and reconstructed as a `DictConfig` inside the task. No type registration, manual encoding or schema declaration is required.

## When to use this plugin

Use `flyteplugins-omegaconf` when:

- You already use OmegaConf. For example, you have YAML configs, dataclass-based config trees or a Hydra app, and want to keep that representation intact across task boundaries.
- You want to pass a single composed config object instead of widening task signatures with dozens of scalar arguments.
- You want to enforce schema validation at the task entry point via dataclass-backed structured configs.
- You want resolved interpolations (`${other.value}`) to be materialized at submission time rather than at task runtime.

If you do not use OmegaConf elsewhere, prefer plain dataclasses, `pydantic.BaseModel` or `dict` for task inputs as they are supported by Flyte natively without an extra dependency.

## Building a DictConfig

Any of the standard OmegaConf construction methods produce a value the plugin can serialize.

### From a plain dict

```python{hl_lines=["1-3"]}
cfg = OmegaConf.create(
    {"optimizer": {"lr": 0.001}, "training": {"epochs": 10}}
)
flyte.run(train, cfg=cfg)
```

### From a YAML file

```python{hl_lines=[1]}
cfg = OmegaConf.load("configs/training.yaml")
flyte.run(train, cfg=cfg)
```

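When you call `OmegaConf.load` in your local submission script, as here, the file is read on the submitting machine, not on the worker. If a task calls `OmegaConf.load` itself at runtime, the YAML file must ship with your code. Use `flyte.with_runcontext(copy_style="all").run(...)` to package your project tree with the task.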

### From a dataclass (structured config)

```python{hl_lines=["3-6", 8]}
from dataclasses import dataclass

@dataclass
class TrainConf:
    lr: float = 0.001
    epochs: int = 10

cfg = OmegaConf.structured(TrainConf())
flyte.run(train, cfg=cfg)
```

Structured configs are covered in detail in **OmegaConf > Structured configs** below.

### From a base config plus overrides

```python{hl_lines=["1-3"]}
base = OmegaConf.load("configs/training.yaml")
override = OmegaConf.create({"optimizer": {"lr": 0.01}})
cfg = OmegaConf.merge(base, override)
flyte.run(train, cfg=cfg)
```

This is the same pattern Hydra uses internally. See the [Hydra integration](../hydra/_index) for a full composition layer on top of this plugin.

## Variable interpolation

OmegaConf supports `${...}` interpolations that resolve relative to the config tree:

```python{hl_lines=[3, 4]}
cfg = OmegaConf.create(
    {
        "base_lr": 0.01,
        "optimizer": {"lr": "${base_lr}", "momentum": 0.9},
    }
)
flyte.run(train, cfg=cfg)
```

Interpolations are resolved at serialization time. By the time the task runs, `cfg.optimizer.lr` is the concrete float `0.01`, not the string `"${base_lr}"`. This means:

- The receiving task does not need any context that only existed in the submitter's environment.
- Resolved values appear in the Flyte I/O panel.
- A reference that fails to resolve at submission time fails fast, before any task runs.

If you need lazy resolution on the worker, resolve the reference yourself inside the task or pass the unresolved string through a normal `str` input.
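To make "resolved at serialization time" concrete, here is a toy resolver over plain dicts. It is not OmegaConf's implementation. It supports only absolute `${a.b}` references, with no custom resolvers, relative references or cycle detection. As in OmegaConf, a value that consists of exactly one reference keeps the target's type. An unresolvable reference raises immediately, which mirrors the fail-fast behavior described above.

```python
import re
from typing import Any

_REF = re.compile(r"\$\{([^}]+)\}")


def _lookup(root: dict, dotted: str) -> Any:
    node = root
    for part in dotted.split("."):
        node = node[part]  # KeyError: fails before anything is sent
    return node


def resolve(node: Any, root: dict) -> Any:
    """Return a copy of node with every ${a.b} reference replaced."""
    if isinstance(node, dict):
        return {key: resolve(value, root) for key, value in node.items()}
    if isinstance(node, list):
        return [resolve(value, root) for value in node]
    if isinstance(node, str):
        whole = _REF.fullmatch(node)
        if whole:
            # A bare reference keeps the target's type (0.01 stays a float).
            return resolve(_lookup(root, whole.group(1)), root)
        # References embedded in a longer string are substituted as text.
        return _REF.sub(lambda m: str(resolve(_lookup(root, m.group(1)), root)), node)
    return node


cfg = {"base_lr": 0.01, "optimizer": {"lr": "${base_lr}", "momentum": 0.9}}
print(resolve(cfg, cfg)["optimizer"])  # {'lr': 0.01, 'momentum': 0.9}
```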

## Nested and deeply structured configs

Deeply nested configs round-trip intact, and the receiving task reaches any leaf with ordinary attribute access:

```python{hl_lines=["1-13", 17]}
cfg = OmegaConf.create(
    {
        "experiment": {
            "model": {
                "encoder": {
                    "attention": {"num_heads": 8, "head_dim": 64},
                    "ffn": {"hidden_dim": 2048, "activation": "gelu"},
                },
                "decoder": {"num_layers": 6},
            }
        }
    }
)

@env.task
async def extract_leaf(cfg: DictConfig) -> int:
    return int(cfg.experiment.model.encoder.attention.num_heads)
```

## DictConfigs that contain lists

A `DictConfig` may hold list values; they are reconstructed as nested `ListConfig`s on the receiving side.

```python{hl_lines=[4, 5, 8, 9]}
cfg = OmegaConf.create(
    {
        "model": {
            "layer_sizes": [64, 128, 256, 512],
            "activations": ["relu", "relu", "relu", "sigmoid"],
        },
        "data": {
            "augmentations": ["random_flip", "random_crop", "color_jitter"],
            "input_size": [224, 224],
        },
    }
)

@env.task
async def double_layer_sizes(cfg: DictConfig) -> DictConfig:
    doubled = [size * 2 for size in cfg.model.layer_sizes]
    return OmegaConf.merge(cfg, {"model": {"layer_sizes": doubled}})
```

## ListConfig as input and output

`ListConfig` is symmetric with `DictConfig` and supports the same construction patterns.

### Lists of primitives

```python{hl_lines=[2]}
@env.task
async def scale_values(values: ListConfig, factor: float) -> ListConfig:
    return OmegaConf.create([v * factor for v in values])
```

### Building a schedule from another task

```python{hl_lines=[3, 7, 8]}
@env.task
async def build_lr_schedule(base_lr: float, num_stages: int) -> ListConfig:
    return OmegaConf.create([base_lr * (0.5 ** i) for i in range(num_stages)])

@env.task
async def train_with_schedule(cfg: DictConfig, lr_schedule: ListConfig) -> float:
    final_lr = float(lr_schedule[-1])
    ...
```

### Nested lists (list of lists)

```python{hl_lines=[1, 6]}
grid = OmegaConf.create([[0.001, 0.01, 0.1], [10, 20, 50]])

@env.task
async def flatten_grid(grid: ListConfig) -> ListConfig:
    flat = [item for sublist in OmegaConf.to_container(grid) for item in sublist]
    return OmegaConf.create(flat)
```

### Lists of DictConfigs

```python{hl_lines=["2-6"]}
configs = OmegaConf.create(
    [
        {"optimizer": {"lr": 0.001}, "training": {"epochs": 10}},
        {"optimizer": {"lr": 0.01},  "training": {"epochs": 20}},
        {"optimizer": {"lr": 0.1},   "training": {"epochs": 5}},
    ]
)

@env.task
async def select_best_config(configs: ListConfig) -> DictConfig:
    best = max(OmegaConf.to_container(configs), key=lambda c: c["optimizer"]["lr"])
    return OmegaConf.create(best)
```

### Lists of dataclass instances

```python{hl_lines=["9-13"]}
@dataclass
class LayerConf:
    name: str
    width: int
    activation: str

layers = OmegaConf.create(
    [
        LayerConf(name="encoder", width=768, activation="gelu"),
        LayerConf(name="bottleneck", width=128, activation="relu"),
        LayerConf(name="decoder", width=768, activation="linear"),
    ]
)
```

Each element round-trips as a typed `DictConfig` backed by `LayerConf`, so the receiving task can call `OmegaConf.get_type(layers[0])` and access fields with attribute notation.

> **📝 Note**
>
> `ListConfig` is always plain. Even when its elements are dataclass-backed, the outer `ListConfig` does not carry a list-level schema, because there is no structured (typed-element) `ListConfig` in OmegaConf. This affects only the outer container; nested elements retain their schemas.

## Structured configs

A structured config is a `DictConfig` that is bound to a Python dataclass. The dataclass acts as a schema: assigning a value of the wrong type raises `omegaconf.ValidationError`, and merging unknown keys raises an error instead of silently extending the config.

### Basic structured config

```python{hl_lines=["4-7", "9-12", 14, 17]}
from dataclasses import dataclass, field
from omegaconf import OmegaConf, DictConfig

@dataclass
class OptimizerConf:
    lr: float = 0.001
    weight_decay: float = 1e-4

@dataclass
class TrainConf:
    optimizer: OptimizerConf = field(default_factory=OptimizerConf)
    epochs: int = 10

cfg = OmegaConf.structured(TrainConf())
flyte.run(train, cfg=cfg)

# cfg.optimizer.lr = "oops"  # raises omegaconf.ValidationError
```
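The two schema rules can be approximated with standard-library dataclasses. The first rule is type checking on assignment. The second is rejection of unknown keys. This sketch is an illustration only. `strict_merge` is a hypothetical helper that raises `KeyError` and `TypeError` where OmegaConf raises its own error types. Unlike `OmegaConf.merge`, it mutates its argument in place rather than returning a new config.

```python
from dataclasses import dataclass, field, fields, is_dataclass


@dataclass
class OptimizerConf:
    lr: float = 0.001
    weight_decay: float = 1e-4


@dataclass
class TrainConf:
    optimizer: OptimizerConf = field(default_factory=OptimizerConf)
    epochs: int = 10


def strict_merge(obj, overrides: dict) -> None:
    """Apply overrides to a dataclass instance in place, enforcing its schema."""
    declared = {f.name: f.type for f in fields(obj)}
    for key, value in overrides.items():
        if key not in declared:
            raise KeyError(f"{type(obj).__name__} has no field {key!r}")
        expected = declared[key]
        if is_dataclass(expected):
            if not isinstance(value, dict):
                raise TypeError(f"{key!r} expects a mapping for {expected.__name__}")
            strict_merge(getattr(obj, key), value)  # recurse into the nested schema
            continue
        if expected is float and isinstance(value, int) and not isinstance(value, bool):
            value = float(value)  # widen ints for float fields
        if not isinstance(value, expected):
            raise TypeError(f"{key!r} expects {expected.__name__}, got {type(value).__name__}")
        setattr(obj, key, value)
```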

### Schema reconstruction in the receiving task

When a structured `DictConfig` is deserialized in a downstream task, the plugin operates in **Auto mode**: it reads the originating dataclass name from the wire payload and tries to import it. Two outcomes are possible:

- Dataclass importable in the receiving task: `cfg` is reconstructed as a `TrainConf`-backed `DictConfig`. `OmegaConf.get_type(cfg)` returns `TrainConf`, and type validation is enforced.
- Dataclass not importable: `cfg` falls back to a plain `DictConfig` carrying the raw values. `OmegaConf.get_type(cfg)` returns `dict`. The values are intact but the schema is lost.

To keep schemas across task hops, define dataclasses in modules that are importable from every task in the pipeline (for example, in a shared `configs.py` module bundled into the task image).
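Auto mode's lookup can be sketched with `importlib`. The sketch splits the stored qualified name, imports the module, fetches the class and falls back to `dict` on any failure. `load_schema` is a hypothetical helper. The fallback name `builtins.dict` matches the plain-config marker described under **Wire format**.

```python
import importlib


def load_schema(qualified_name: str) -> type:
    """Resolve 'package.module.ClassName', falling back to dict on failure."""
    module_name, _, class_name = qualified_name.rpartition(".")
    try:
        return getattr(importlib.import_module(module_name), class_name)
    except (ImportError, AttributeError, ValueError):
        # Not importable here: values survive, the schema does not.
        return dict
```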

### Required (`MISSING`) fields

OmegaConf's `MISSING` sentinel marks a required field that has no default:

```python{hl_lines=[1, 5, "8-9", "12-13"]}
from omegaconf import MISSING

@dataclass
class TrainConf:
    data_path: str = MISSING
    epochs: int = 10

# Pass with MISSING still unset — serialization succeeds.
cfg = OmegaConf.structured(TrainConf())
flyte.run(train, cfg=cfg)

# Or fill it before passing.
cfg = OmegaConf.structured(TrainConf(data_path="/data/imagenet"))
flyte.run(train, cfg=cfg)
```

A config with an unset `MISSING` field serializes and deserializes successfully, because the sentinel itself is preserved on the wire. Accessing that field in the receiving task raises `MissingMandatoryValue`.
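The `MISSING` round trip can be illustrated without OmegaConf. The sentinel is encoded as an ordinary tagged value, so serialization never has to read it. Only attribute access in the receiving task raises. All names below are hypothetical stand-ins, including `MISSING`, `MissingValueError`, `LazyConfig` and the `__missing__` tag. OmegaConf's real sentinel is the string `"???"`, and its error is `MissingMandatoryValue`.

```python
class _Missing:
    """Hypothetical sentinel for a required value that has not been set."""

    def __repr__(self) -> str:
        return "???"


MISSING = _Missing()


class MissingValueError(Exception):
    """Stand-in for omegaconf's MissingMandatoryValue."""


def encode(values: dict) -> dict:
    # The sentinel becomes an ordinary tagged value, so encoding never fails.
    return {k: {"__missing__": True} if v is MISSING else v for k, v in values.items()}


def decode(payload: dict) -> dict:
    return {k: MISSING if v == {"__missing__": True} else v for k, v in payload.items()}


class LazyConfig:
    """Holds decoded values; raises only when a MISSING field is read."""

    def __init__(self, values: dict):
        self._values = values

    def __getattr__(self, name: str):
        try:
            value = self._values[name]
        except KeyError:
            raise AttributeError(name) from None
        if value is MISSING:
            raise MissingValueError(f"Missing mandatory value: {name}")
        return value
```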

> **📝 Note**
>
> Type annotations on unset fields survive only when Auto mode can import the dataclass. In that case, an unfilled `MISSING` field still carries its declared type (e.g. `StringNode` for `str`). When the plugin falls back to a plain `DictConfig` because the dataclass is not importable, the field becomes an `AnyNode`. The value is preserved, but the type annotation is not.

### Advanced field types

Beyond primitives and nested dataclasses, structured configs may declare fields of these types and they will round-trip with their schemas intact:

- `Enum` subclasses
- `pathlib.Path`
- `Optional[T]`
- `bytes`
- `dict[str, T]` where `T` is a dataclass
- `list[T]` where `T` is a dataclass

```python{hl_lines=["8-10", "20-35"]}
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path
from typing import Optional

from omegaconf import MISSING

class RunMode(Enum):
    TRAIN = "train"
    EVAL = "eval"

@dataclass
class CallbackConf:
    name: str = "early_stop"
    patience: int = 3
    monitor: str = MISSING

@dataclass
class AdvancedTrainConf:
    mode: RunMode = RunMode.TRAIN
    checkpoint_dir: Path = Path("/tmp/checkpoints")
    maybe_seed: Optional[int] = None
    payload: bytes = b"default-token"
    callbacks_by_name: dict[str, CallbackConf] = field(
        default_factory=lambda: {
            "early_stop": CallbackConf(name="early_stop", patience=3),
            "checkpoint": CallbackConf(name="checkpoint", monitor="val_loss"),
        }
    )
    callbacks: list[CallbackConf] = field(
        default_factory=lambda: [
            CallbackConf(name="lr_monitor", patience=2, monitor="lr"),
            CallbackConf(name="nan_guard", patience=1, monitor="loss"),
        ]
    )
```

Inside a downstream task:

```python
@env.task
async def inspect(cfg: DictConfig) -> str:
    assert OmegaConf.get_type(cfg) == AdvancedTrainConf
    assert OmegaConf.get_type(cfg.callbacks[0]) == CallbackConf
    assert isinstance(cfg.mode, RunMode)
    assert isinstance(cfg.checkpoint_dir, Path)
    assert isinstance(cfg.payload, bytes)
    return cfg.mode.value
```

### Merging overrides on top of a structured base

With the `TrainConf` schema from **Basic structured config** above, every override key must correspond to a declared field:

```python{hl_lines=[3, 10]}
@env.task
async def structured_merge_pipeline() -> str:
    base = OmegaConf.structured(TrainConf())
    overrides = OmegaConf.create(
        {
            "optimizer": {"lr": 0.05},
            "epochs": 100,
        }
    )
    cfg = OmegaConf.merge(base, overrides)
    return await validate_config(cfg)
```

Merging an unknown key against a structured config raises an error, so define every key the override layer might supply on the dataclass.

## Embedding rich Python values inside a plain DictConfig

A plain `DictConfig` (one not bound to a dataclass) can still hold Python values that OmegaConf does not natively model. The plugin preserves the following types end-to-end whether they appear in plain or structured configs:

- `pathlib.Path` and any subclass of `pathlib.PurePath`
- `enum.Enum` members
- `tuple` (round-trips as `tuple`, not `list`)
- `bytes`

```python{hl_lines=[1]}
cfg = OmegaConf.create({"model_path": Path("/opt/models/model.bin")})

@env.task
async def use_path(cfg: DictConfig) -> str:
    assert isinstance(cfg.model_path, Path)
    return f"model_path={cfg.model_path}"
```

If an `Enum`'s class cannot be imported in the receiving environment, the value is returned as the underlying primitive (`int`, `str`, ...) instead of the enum member.
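A plain serialization round trip loses types such as `Path`, `tuple` and `Enum`, which is why the payload tags them. The sketch below uses JSON as a stand-in for MessagePack. It invents its own tag keys (`__enum__`, `__path__`, `__tuple__`); the plugin's real tags are internal. `bytes` is omitted because JSON has no binary type. Every `PurePath` is restored as `Path`. The enum branch shows the documented fallback: if the class cannot be imported, the primitive value is returned.

```python
import importlib
from enum import Enum
from pathlib import Path, PurePath


def encode(value):
    """Replace values a JSON round trip would lose with tagged dicts."""
    if isinstance(value, Enum):
        cls = type(value)
        return {"__enum__": f"{cls.__module__}.{cls.__qualname__}", "value": value.value}
    if isinstance(value, PurePath):
        return {"__path__": str(value)}
    if isinstance(value, tuple):
        return {"__tuple__": [encode(v) for v in value]}
    if isinstance(value, list):
        return [encode(v) for v in value]
    if isinstance(value, dict):
        return {k: encode(v) for k, v in value.items()}
    return value


def decode(value):
    if isinstance(value, list):
        return [decode(v) for v in value]
    if not isinstance(value, dict):
        return value
    if "__enum__" in value:
        module_name, _, cls_name = value["__enum__"].rpartition(".")
        try:
            return getattr(importlib.import_module(module_name), cls_name)(value["value"])
        except (ImportError, AttributeError, ValueError):
            return value["value"]  # enum class unavailable: keep the primitive
    if "__path__" in value:
        return Path(value["__path__"])  # simplification: every PurePath comes back as Path
    if "__tuple__" in value:
        return tuple(decode(v) for v in value["__tuple__"])
    return {k: decode(v) for k, v in value.items()}
```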

## Reserved-looking keys

The plugin's wire format uses an internal payload marker (`__flyte_omegaconf__`), which means user-facing keys named `kind`, `values`, `name`, `value`, `type`, or `schema` round-trip unchanged:

```python{hl_lines=[1, 8]}
cfg = OmegaConf.create({"kind": "training-job", "values": {"lr": 0.001}})

@env.task
async def use_payload_shaped_config(cfg: DictConfig) -> str:
    # cfg.values resolves to DictConfig.values() — use bracket notation
    # to reach the user key named "values".
    return f"kind={cfg.kind} lr={cfg['values'].lr}"
```

The only practical consideration is Python's normal attribute-vs-method conflict: `cfg.values` is the `.values()` method, so reach for `cfg["values"]` when your config has a key with that name.
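Collisions are avoided because plugin metadata sits under a single reserved key instead of alongside user keys. Only the marker name `__flyte_omegaconf__` comes from this page. The inner `schema`/`values` layout in this sketch is an assumption, and `wrap`/`unwrap` are hypothetical helpers.

```python
MARKER = "__flyte_omegaconf__"  # marker name from this page; inner layout assumed


def wrap(values: dict, schema: str = "builtins.dict") -> dict:
    """Nest plugin metadata under one reserved key, away from user keys."""
    return {MARKER: {"schema": schema, "values": values}}


def unwrap(payload: dict) -> tuple:
    inner = payload[MARKER]
    return inner["schema"], inner["values"]
```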

## YAML reports

The Flyte I/O panel displays the literal wire representation of a `DictConfig`.

![Wire Representation](https://raw.githubusercontent.com/unionai/unionai-docs-static/main/images/integrations/omegaconf/input.png)

For a YAML view, enable a Flyte report on the task and log the config with `log_yaml`:

```python{hl_lines=[1, 3, 5]}
from flyteplugins.omegaconf import log_yaml

@env.task(report=True)
async def train(cfg: DictConfig) -> DictConfig:
    await log_yaml.aio(cfg, title="Input config")
    ...
```

![YAML Report](https://raw.githubusercontent.com/unionai/unionai-docs-static/main/images/integrations/omegaconf/yaml_repr.png)

The plugin also exposes:

- `to_yaml(cfg)`: render an OmegaConf container as a YAML string.
- `to_html(cfg, title=...)`: wrap the YAML in escaped HTML for embedding in a custom report.
- `replace_yaml(cfg, ...)`: replace the contents of a report tab instead of appending.

```python
from flyteplugins.omegaconf.report import to_yaml, replace_yaml

text = to_yaml(cfg)
await replace_yaml.aio(cfg, tab="Final config")
```

`MISSING` fields appear as `???` in the YAML output, matching OmegaConf's own convention.

## Wire format

Both `DictConfig` and `ListConfig` are serialized as MessagePack blobs with the literal representation:

```
Literal(scalar=Scalar(binary=Binary(value=<msgpack bytes>, tag="msgpack")))
```

The msgpack payload uses an internal tagged structure to distinguish OmegaConf-specific concepts from raw values:

- A `DictConfig` payload includes the originating dataclass name (`builtins.dict` for plain configs) plus its values.
- `MISSING`, `Enum`, `Path`, and `tuple` values carry tagged shapes so they can be reconstructed faithfully.

You normally do not need to inspect this format. Two of its properties do affect task behavior:

- The plugin serializes with `resolve=True`, so the wire representation always contains concrete values for `${...}` interpolations.
- Cache-key metadata is set via Flyte's `MESSAGEPACK` serialization format, so repeated invocations of a cached task with equivalent configs hit the same cache entry.

## End-to-end example

The example below ties the pieces together: a structured `DictConfig` is created in a parent task, flows through several child tasks that read and modify it, and a `ListConfig` produced midway is consumed by a later stage. Each hop serializes and deserializes the config; the dataclass schema is recovered on the receiving side because `TrainConf` (and friends) are importable in every task in the pipeline.

```python
from dataclasses import dataclass, field

import flyte
from omegaconf import DictConfig, ListConfig, OmegaConf

env = flyte.TaskEnvironment(
    name="omegaconf-pipeline-example",
    image=flyte.Image.from_debian_base().with_pip_packages("flyteplugins-omegaconf"),
)

@dataclass
class OptimizerConf:
    lr: float = 0.001
    weight_decay: float = 1e-4

@dataclass
class DataConf:
    path: str = ""
    preprocessed: bool = False

@dataclass
class ResultsConf:
    val_loss: float = 0.0
    final_lr: float = 0.0
    num_lr_steps: int = 0

@dataclass
class TrainConf:
    optimizer: OptimizerConf = field(default_factory=OptimizerConf)
    data: DataConf = field(default_factory=DataConf)
    results: ResultsConf = field(default_factory=ResultsConf)
    epochs: int = 10
    batch_size: int = 32
    experiment: str = "baseline"

@env.task
async def preprocess(cfg: DictConfig, dataset: str) -> DictConfig:
    """First stage: fills in the data section of cfg."""
    return OmegaConf.merge(cfg, {"data": {"path": dataset, "preprocessed": True}})

@env.task
async def build_schedule(cfg: DictConfig) -> ListConfig:
    """Produces an LR schedule from cfg as a ListConfig."""
    lrs = [cfg.optimizer.lr * (0.5**i) for i in range(cfg.epochs)]
    return OmegaConf.create(lrs)

@env.task
async def train(cfg: DictConfig, lr_schedule: ListConfig) -> tuple[DictConfig, float]:
    """Simulates training. Returns the final cfg (with results filled in) and val loss."""
    final_lr = float(lr_schedule[-1])
    val_loss = final_lr * 10  # placeholder
    result_cfg = OmegaConf.merge(
        cfg,
        {
            "results": {
                "val_loss": val_loss,
                "final_lr": final_lr,
                "num_lr_steps": len(lr_schedule),
            }
        },
    )
    return result_cfg, val_loss

@env.task
async def evaluate(result_cfg: DictConfig, val_loss: float) -> str:
    """Final stage: formats a report from the result config."""
    return (
        f"experiment={result_cfg.experiment} "
        f"data={result_cfg.data.path} "
        f"val_loss={val_loss:.6f} "
        f"final_lr={result_cfg.results.final_lr:.6f} "
        f"lr_steps={result_cfg.results.num_lr_steps}"
    )

@env.task
async def training_pipeline(dataset: str) -> str:
    """Full pipeline: cfg flows through preprocess, build_schedule, train and evaluate."""
    cfg = OmegaConf.structured(
        TrainConf(
            optimizer=OptimizerConf(lr=0.01, weight_decay=1e-5),
            epochs=5,
            batch_size=64,
            experiment="structured-cfg-pipeline",
        )
    )

    preprocessed_cfg = await preprocess(cfg, dataset=dataset)
    lr_schedule = await build_schedule(preprocessed_cfg)
    result_cfg, val_loss = await train(preprocessed_cfg, lr_schedule=lr_schedule)
    return await evaluate(result_cfg, val_loss=val_loss)

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(training_pipeline, dataset="s3://my-bucket/imagenet")
    print(f"Run URL: {run.url}")
    print(f"Outputs: {run.outputs()}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/omegaconf/example.py*
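
To check what the pipeline should report before running it on a cluster, you can replay the arithmetic in `preprocess`, `build_schedule`, `train` and `evaluate` locally. This is a sketch under stated assumptions: plain dicts stand in for the OmegaConf objects, and a small hypothetical `deep_merge` helper approximates `OmegaConf.merge` by merging nested keys rather than replacing them. Unlike OmegaConf, it does no type validation against `TrainConf`.

```python
from copy import deepcopy


def deep_merge(base: dict, override: dict) -> dict:
    """Hypothetical stand-in for OmegaConf.merge: nested dicts are merged, not replaced."""
    merged = deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = value
    return merged


# Plain-dict version of OmegaConf.structured(TrainConf(...)) from training_pipeline.
cfg = {
    "optimizer": {"lr": 0.01, "weight_decay": 1e-5},
    "data": {"path": "", "preprocessed": False},
    "results": {"val_loss": 0.0, "final_lr": 0.0, "num_lr_steps": 0},
    "epochs": 5,
    "batch_size": 64,
    "experiment": "structured-cfg-pipeline",
}

# preprocess: fill in the data section.
cfg = deep_merge(cfg, {"data": {"path": "s3://my-bucket/imagenet", "preprocessed": True}})

# build_schedule: halve the learning rate every epoch.
lr_schedule = [cfg["optimizer"]["lr"] * (0.5**i) for i in range(cfg["epochs"])]

# train: placeholder loss derived from the last learning rate.
final_lr = float(lr_schedule[-1])
val_loss = final_lr * 10
result_cfg = deep_merge(
    cfg,
    {"results": {"val_loss": val_loss, "final_lr": final_lr, "num_lr_steps": len(lr_schedule)}},
)

# evaluate: same report format as the Flyte task.
report = (
    f"experiment={result_cfg['experiment']} "
    f"data={result_cfg['data']['path']} "
    f"val_loss={val_loss:.6f} "
    f"final_lr={result_cfg['results']['final_lr']:.6f} "
    f"lr_steps={result_cfg['results']['num_lr_steps']}"
)
print(report)
# experiment=structured-cfg-pipeline data=s3://my-bucket/imagenet val_loss=0.006250 final_lr=0.000625 lr_steps=5
```

Because the merge is recursive, `train` can add a `results` section without discarding the `data` section that `preprocess` filled in.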

For more focused examples, such as plain `DictConfig` patterns, advanced `ListConfig` shapes, and every `MISSING`/`Enum`/`Path`/`bytes` case, see the [plugin repository](https://github.com/flyteorg/flyte-sdk/tree/main/plugins/omegaconf/examples).

=== PAGE: https://www.union.ai/docs/v2/union/integrations/openai ===

# OpenAI

The OpenAI plugin provides a drop-in replacement for the [OpenAI Agents SDK](https://openai.github.io/openai-agents-python/) `function_tool` decorator. It lets you use Flyte tasks as tools in agentic workflows so that tool calls run as tracked, reproducible Flyte task executions.

## When to use this plugin

- You're building agentic workflows with the OpenAI Agents SDK on Flyte
- You want tool calls to run as Flyte tasks with full observability, retries, and caching
- You want to combine LLM agents with existing Flyte pipelines

## Installation

```bash
pip install flyteplugins-openai
```

Requires `openai-agents >= 0.2.4`.

## Usage

The plugin provides a single decorator, `function_tool`, that wraps Flyte tasks as OpenAI agent tools.

### `function_tool`

When applied to a Flyte task (a function decorated with `@env.task`), `function_tool` makes that task available as an OpenAI `FunctionTool`. The agent can call it like any other tool, and the call executes as a Flyte task.

When applied to a regular function or a `@flyte.trace`-decorated function, it delegates directly to the OpenAI Agents SDK's built-in `function_tool`.
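
The dispatch rule above can be pictured with a toy model. Everything in it is hypothetical: `ToyTask` stands in for the object that `@env.task` produces and `toy_function_tool` for the plugin's decorator. The real plugin returns an OpenAI `FunctionTool` and submits a tracked Flyte execution, which this sketch does not attempt.

```python
import asyncio
import inspect
from dataclasses import dataclass
from typing import Any, Awaitable, Callable


@dataclass
class ToyTask:
    """Hypothetical stand-in for the object that @env.task produces."""
    fn: Callable[..., Awaitable[Any]]

    async def submit(self, **kwargs: Any) -> Any:
        # A real Flyte task would start a tracked, retryable execution here.
        return await self.fn(**kwargs)


def toy_function_tool(obj: Any) -> Callable[..., Awaitable[Any]]:
    """Route calls through the task when given a task; otherwise call the function directly."""
    if isinstance(obj, ToyTask):
        async def call_task(**kwargs: Any) -> Any:
            return await obj.submit(**kwargs)
        call_task.route = "flyte-task"
        return call_task

    async def call_plain(**kwargs: Any) -> Any:
        result = obj(**kwargs)
        return await result if inspect.isawaitable(result) else result
    call_plain.route = "sdk-default"
    return call_plain


async def get_weather(city: str) -> str:
    return f"Sunny in {city}"


task_tool = toy_function_tool(ToyTask(get_weather))
plain_tool = toy_function_tool(lambda city: f"Cloudy in {city}")
print(asyncio.run(task_tool(city="Tokyo")), task_tool.route)  # Sunny in Tokyo flyte-task
```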

### Basic pattern

1. Define a `TaskEnvironment` with your image and secrets
2. Decorate your task functions with `@function_tool` and `@env.task`
3. Pass the tools to an `Agent`
4. Run the agent from another Flyte task

```python
from agents import Agent, Runner
from pydantic import BaseModel

import flyte
from flyteplugins.openai.agents import function_tool

env = flyte.TaskEnvironment(
    name="openai_agents",
    resources=flyte.Resources(cpu=1, memory="250Mi"),
    # Builds the image from this file's uv script metadata (see the full example below).
    image=flyte.Image.from_uv_script(__file__, name="openai_agents_image"),
    secrets=flyte.Secret("openai_api_key", as_env_var="OPENAI_API_KEY"),
)

class Weather(BaseModel):
    city: str
    temperature_range: str
    conditions: str

# @function_tool goes on top so that it wraps the Flyte task.
@function_tool
@env.task
async def get_weather(city: str) -> Weather:
    """Get the weather for a given city."""
    return Weather(city=city, temperature_range="14-20C", conditions="Sunny")

agent = Agent(
    name="Weather Agent",
    instructions="You are a helpful agent.",
    tools=[get_weather],
)

# The agent runs inside another Flyte task; each tool call runs as a Flyte task.
@env.task
async def main() -> str:
    result = await Runner.run(agent, input="What's the weather in Tokyo?")
    return result.final_output
```

> **📝 Note**
>
> The docstring on each `@function_tool` task is sent to the LLM as the tool description. Write clear, concise docstrings that describe what the tool does and what its parameters mean.
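
To illustrate why the docstring matters, the sketch below builds a JSON-schema-style tool description from a function's name, docstring and type hints. `describe_tool` is a hypothetical helper, not the OpenAI Agents SDK's implementation (the SDK generates its own schema); the point is only that the docstring becomes the description the model reads.

```python
import inspect
from typing import Any, get_type_hints

# Hypothetical mapping from Python annotations to JSON Schema type names.
JSON_TYPES = {str: "string", int: "integer", float: "number", bool: "boolean"}


def describe_tool(fn: Any) -> dict[str, Any]:
    """Build a JSON-schema-style tool description from a function's name, docstring and hints."""
    hints = get_type_hints(fn)
    params = inspect.signature(fn).parameters
    return {
        "name": fn.__name__,
        "description": inspect.getdoc(fn) or "",  # the docstring is what the model reads
        "parameters": {
            "type": "object",
            "properties": {
                name: {"type": JSON_TYPES.get(hints.get(name), "string")} for name in params
            },
            "required": [n for n, p in params.items() if p.default is inspect.Parameter.empty],
        },
    }


async def get_weather(city: str, units: str = "C") -> str:
    """Get the weather for a given city."""
    return f"14-20{units} in {city}"


print(describe_tool(get_weather)["description"])  # Get the weather for a given city.
```

A vague docstring here means a vague description for the model, which makes it harder for the agent to decide when to call the tool.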

### Secrets

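Store your OpenAI API key as a Flyte secret (for example, with `flyte create secret openai_api_key`) and expose it to the task as an environment variable. The name you pass to `flyte.Secret` must match the secret's name: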

```python
secrets=flyte.Secret("openai_api_key", as_env_var="OPENAI_API_KEY")
```

## Example

```python
"""OpenAI Agents with Flyte, basic tool example.

Usage:

Create secret:

    flyte create secret openai_api_key

Run:

    uv run agents_tools.py
"""

# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "flyteplugins-openai>=2.0.0b7",
#    "openai-agents>=0.2.4",
#    "pydantic>=2.10.6",
# ]
# main = "main"
# params = ""
# ///

from agents import Agent, Runner
from pydantic import BaseModel

import flyte
from flyteplugins.openai.agents import function_tool

env = flyte.TaskEnvironment(
    name="openai_agents_tools",
    resources=flyte.Resources(cpu=1, memory="250Mi"),
    image=flyte.Image.from_uv_script(__file__, name="openai_agents_image"),
    secrets=flyte.Secret("openai_api_key", as_env_var="OPENAI_API_KEY"),
)

class Weather(BaseModel):
    city: str
    temperature_range: str
    conditions: str

@function_tool
@env.task
async def get_weather(city: str) -> Weather:
    """Get the weather for a given city."""
    return Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind.")

agent = Agent(
    name="Hello world",
    instructions="You are a helpful agent.",
    tools=[get_weather],
)

@env.task
async def main() -> str:
    result = await Runner.run(agent, input="What's the weather in Tokyo?")
    print(result.final_output)
    return result.final_output

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
    run.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/openai/openai/agents_tools.py*

## API reference

See the [OpenAI API reference](https://www.union.ai/docs/v2/union/api-reference/integrations/openai/_index) for full details.

=== PAGE: https://www.union.ai/docs/v2/union/integrations/openai/agent_tools ===

# Agent tools

In this example, we'll use the `openai-agents` library to create a simple agent that can use tools to perform tasks.
This example is based on the [basic tools example](https://github.com/openai/openai-agents-python/blob/main/examples/basic/tools.py) from the `openai-agents-python` repo.

First, create an OpenAI API key, which you can get from the [OpenAI website](https://platform.openai.com/account/api-keys).
Then, create a secret on your Flyte cluster. Its name must match the one the task environment passes to `flyte.Secret` (`openai_api_key`):

```bash
flyte create secret openai_api_key --value <your-api-key>
```

Then, we'll declare our dependencies with `uv` inline script metadata (PEP 723). `flyte.Image.from_uv_script` reads the same block to build the task image:

```python
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "flyteplugins-openai>=2.0.0b7",
#    "openai-agents>=0.2.4",
#    "pydantic>=2.10.6",
# ]
# main = "main"
# params = ""
# ///
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/openai/openai/agents_tools.py*

Next, we'll import the libraries and create a `TaskEnvironment`, which we need to run the example:

```python
from agents import Agent, Runner
from pydantic import BaseModel

import flyte
from flyteplugins.openai.agents import function_tool

env = flyte.TaskEnvironment(
    name="openai_agents_tools",
    resources=flyte.Resources(cpu=1, memory="250Mi"),
    image=flyte.Image.from_uv_script(__file__, name="openai_agents_image"),
    secrets=flyte.Secret("openai_api_key", as_env_var="OPENAI_API_KEY"),
)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/openai/openai/agents_tools.py*

## Define the tools

We'll define a tool that gets weather information for a given city. In this case, we'll use a toy function that returns a hard-coded `Weather` object.

```python
class Weather(BaseModel):
    city: str
    temperature_range: str
    conditions: str

@function_tool
@env.task
async def get_weather(city: str) -> Weather:
    """Get the weather for a given city."""
    return Weather(city=city, temperature_range="14-20C", conditions="Sunny with wind.")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/openai/openai/agents_tools.py*

In this code snippet, the `@function_tool` decorator is imported from `flyteplugins.openai.agents`, a drop-in replacement for the `@function_tool` decorator from the `openai-agents` library.

## Define the agent

Then, we'll define the agent, which calls the tool:

```python
agent = Agent(
    name="Hello world",
    instructions="You are a helpful agent.",
    tools=[get_weather],
)

@env.task
async def main() -> str:
    result = await Runner.run(agent, input="What's the weather in Tokyo?")
    print(result.final_output)
    return result.final_output
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/openai/openai/agents_tools.py*

## Run the agent

Finally, we'll run the agent. Create a `config.yaml` file, which `flyte.init_from_config()` uses to connect to the Flyte cluster:

```bash
flyte create config \
--output ~/.flyte/config.yaml \
--endpoint demo.hosted.unionai.cloud/ \
--project flytesnacks \
--domain development \
--builder remote
```

Then run the script with `uv run agents_tools.py`. Its entry point initializes the client from that config, submits `main`, prints the run URL, and waits for the run to finish:

```python
if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
    run.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/openai/openai/agents_tools.py*

## Conclusion

In this example, we've seen how to use the `openai-agents` library to create a simple agent that can use tools to perform tasks.

The full code is available [here](https://github.com/unionai/unionai-examples/tree/main/v2/integrations/flyte-plugins/openai/openai).

=== PAGE: https://www.union.ai/docs/v2/union/integrations/pandera ===

# Pandera

The [Pandera](https://pandera.readthedocs.io/en/latest/) plugin validates dataframes at task boundaries using
[`DataFrameModel`](https://pandera.readthedocs.io/en/latest/dataframe_models.html) schemas. When a task receives or
returns a pandera-typed dataframe, the plugin automatically validates the data, raises or warns on schema violations,
and writes an HTML validation report to the Flyte deck.

Pandera supports multiple dataframe backends. The `flyteplugins-pandera` plugin handles:

| Pandera typing module | DataFrame library | Additional plugin |
|-|-|-|
| `pandera.typing.pandas` | pandas | — |
| `pandera.typing.polars` | Polars (eager and lazy) | `flyteplugins-polars` |
| `pandera.typing.pyspark_sql` | PySpark SQL | `flyteplugins-spark` |

## When to use this plugin

- You want runtime guarantees that data flowing between tasks conforms to a declared schema
- You need column-level type, constraint, and statistical checks on task inputs and outputs
- You want automatic validation reports visible in the Flyte UI

## Installation

Install the plugin with the pandera extras for your dataframe backend:

### pandas

```bash
pip install flyteplugins-pandera 'pandera[pandas]'
```

### Polars

```bash
pip install flyteplugins-pandera flyteplugins-polars 'pandera[polars]'
```

### PySpark SQL

```bash
pip install flyteplugins-pandera flyteplugins-spark 'pandera[pyspark]'
```

## Defining schemas

Schemas are defined as Python classes that inherit from pandera's `DataFrameModel`. Each field declares a column name,
type, and optional constraints:

```python
import pandera.pandas as pa

class EmployeeSchema(pa.DataFrameModel):
    employee_id: int = pa.Field(ge=0)
    name: str

class EmployeeSchemaWithStatus(EmployeeSchema):
    status: str = pa.Field(isin=["active", "inactive"])
```

Schemas compose through inheritance: `EmployeeSchemaWithStatus` includes all columns from `EmployeeSchema` plus the
`status` column.
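
To make the constraints concrete, here is a stdlib-only toy that applies the same rules to a dict of columns (the shape passed to `pd.DataFrame` in the examples below). The rule dicts compose the way the schema classes do: `WITH_STATUS_RULES` extends `EMPLOYEE_RULES`. This is a hypothetical illustration, not pandera's validation engine.

```python
from typing import Any, Callable

Rule = Callable[[Any], bool]

# Mirrors EmployeeSchema: employee_id is an int >= 0, name is a str.
EMPLOYEE_RULES: dict[str, Rule] = {
    "employee_id": lambda v: isinstance(v, int) and v >= 0,
    "name": lambda v: isinstance(v, str),
}

# Mirrors EmployeeSchemaWithStatus: every parent rule plus a status check.
WITH_STATUS_RULES: dict[str, Rule] = {
    **EMPLOYEE_RULES,
    "status": lambda v: v in {"active", "inactive"},
}


def check_columns(columns: dict[str, list], rules: dict[str, Rule]) -> list[str]:
    """Return one message per missing column or per column with failing values."""
    errors = []
    for name, ok in rules.items():
        if name not in columns:
            errors.append(f"missing column {name!r}")
            continue
        bad = [v for v in columns[name] if not ok(v)]
        if bad:
            errors.append(f"column {name!r} failed on {bad!r}")
    return errors


employees = {"employee_id": [1, 2, 3], "name": ["Ada", "Grace", "Barbara"]}
print(check_columns(employees, EMPLOYEE_RULES))     # []
print(check_columns(employees, WITH_STATUS_RULES))  # ["missing column 'status'"]
```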

For full details on schema definition—including custom checks, regex column matching, and `Config` options—see the
[pandera DataFrameModel documentation](https://pandera.readthedocs.io/en/latest/dataframe_models.html).

## Using schemas in tasks

Annotate task inputs and outputs with pandera's generic `DataFrame` type. The plugin validates data on every
encode (output) and decode (input):

```python
import pandas as pd
import pandera.typing.pandas as pt

# Assumes `env` is a flyte.TaskEnvironment and the schemas are defined as above.
@env.task(report=True)
async def build_employees() -> pt.DataFrame[EmployeeSchema]:
    return pd.DataFrame({
        "employee_id": [1, 2, 3],
        "name": ["Ada", "Grace", "Barbara"],
    })

@env.task(report=True)
async def add_status(
    df: pt.DataFrame[EmployeeSchema],
) -> pt.DataFrame[EmployeeSchemaWithStatus]:
    return df.assign(status="active")
```

Setting `report=True` on the task makes validation reports visible as deck tabs in the Flyte UI.
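
The encode/decode behavior described above amounts to one check on each side of the task body. The decorator below is a hypothetical, stdlib-only sketch of that idea using dict-of-columns tables; the plugin performs its checks as part of encoding and decoding task data rather than by wrapping your function, and it also renders the deck report.

```python
import asyncio
import functools
from typing import Any, Awaitable, Callable

Checker = Callable[[Any], list[str]]


def validate_boundaries(check_in: Checker, check_out: Checker):
    """Hypothetical decorator: validate the argument before the body runs and the result after."""
    def decorator(fn: Callable[[Any], Awaitable[Any]]):
        @functools.wraps(fn)
        async def wrapper(data: Any) -> Any:
            if errors := check_in(data):  # "decode" side: task input
                raise ValueError(f"input failed validation: {errors}")
            result = await fn(data)
            if errors := check_out(result):  # "encode" side: task output
                raise ValueError(f"output failed validation: {errors}")
            return result
        return wrapper
    return decorator


def has_columns(*names: str) -> Checker:
    """Build a checker that reports required columns missing from a dict-of-columns table."""
    return lambda table: [f"missing column {n!r}" for n in names if n not in table]


@validate_boundaries(
    check_in=has_columns("employee_id", "name"),
    check_out=has_columns("employee_id", "name", "status"),
)
async def add_status(table: dict) -> dict:
    return {**table, "status": ["active"] * len(table["employee_id"])}


print(asyncio.run(add_status({"employee_id": [1, 2], "name": ["Ada", "Grace"]}))["status"])
# ['active', 'active']
```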

## Error handling with `ValidationConfig`

By default, a validation failure raises an exception and fails the task. To downgrade failures to warnings instead,
annotate the parameter with `ValidationConfig(on_error="warn")`:

```python
from typing import Annotated
from flyteplugins.pandera import ValidationConfig

@env.task(report=True)
async def lenient_pass_through(
    df: Annotated[pt.DataFrame[EmployeeSchema], ValidationConfig(on_error="warn")],
) -> Annotated[pt.DataFrame[EmployeeSchemaWithStatus], ValidationConfig(on_error="warn")]:
    ...
```

| `on_error` value | Behavior |
|-|-|
| `"raise"` (default) | Validation failure raises `pandera.errors.SchemaError` and the task fails |
| `"warn"` | Validation failure logs a warning and writes the report, but the task continues |

You can mix `"raise"` and `"warn"` across inputs and outputs of the same task. For example, use `"warn"` on inputs
to accept best-effort data while still enforcing strict output contracts.
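
The two modes, and mixing them across a task's input and output, can be modeled in a few lines. `ToySchemaError` and `handle_errors` are hypothetical stand-ins: the real plugin raises pandera's own error type and also writes the validation report to the deck.

```python
import logging
from typing import Literal

logger = logging.getLogger("toy_validation")
OnError = Literal["raise", "warn"]


class ToySchemaError(Exception):
    """Hypothetical stand-in for pandera.errors.SchemaError."""


def handle_errors(errors: list[str], on_error: OnError = "raise") -> bool:
    """Return True when valid; otherwise raise ("raise") or log and return False ("warn")."""
    if not errors:
        return True
    message = "; ".join(errors)
    if on_error == "raise":
        raise ToySchemaError(message)
    logger.warning("validation failed, continuing: %s", message)
    return False


# Best-effort input: the task keeps going.
input_ok = handle_errors(["missing column 'name'"], on_error="warn")

# Strict output contract: the task fails.
try:
    handle_errors(["missing column 'name'"], on_error="raise")
except ToySchemaError as exc:
    print("task failed:", exc)  # task failed: missing column 'name'
```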

## Image configuration

Include the plugin in your task image. The exact setup depends on your dataframe backend:

### pandas

```python
import flyte

img = flyte.Image.from_debian_base(
    python_version=(3, 12),
).with_pip_packages("flyteplugins-pandera", "pandera[pandas]")

env = flyte.TaskEnvironment(
    "pandera_pandas",
    image=img,
    resources=flyte.Resources(cpu="1", memory="2Gi"),
)
```

### Polars

```python
import flyte

img = (
    flyte.Image.from_debian_base(python_version=(3, 12))
    .with_pip_packages("flyteplugins-pandera", "flyteplugins-polars", "pandera[polars]")
)

env = flyte.TaskEnvironment(
    "pandera_polars",
    image=img,
    resources=flyte.Resources(cpu="1", memory="2Gi"),
)
```

### PySpark SQL

```python
import flyte
from flyteplugins.spark.task import Spark

image = (
    flyte.Image.from_base("apache/spark-py:v3.4.0")
    .clone(name="pandera-pyspark-sql", python_version=(3, 10), extendable=True)
    .with_pip_packages("flyteplugins-pandera", "flyteplugins-spark", "pandera[pyspark]")
)

spark_conf = Spark(
    spark_conf={
        "spark.driver.memory": "1000M",
        "spark.executor.memory": "1000M",
        "spark.executor.cores": "1",
        "spark.executor.instances": "2",
        "spark.driver.cores": "1",
    },
)

env = flyte.TaskEnvironment(
    name="pandera_pyspark",
    plugin_config=spark_conf,
    image=image,
    resources=flyte.Resources(cpu="1", memory="2Gi"),
)
```

## Polars lazy frames

The Polars backend supports both `pt.DataFrame` (eager) and `pt.LazyFrame` (lazy). With lazy frames, pandera
validates the data when the frame is materialized at task I/O boundaries:

```python
import pandera.polars as pa
import pandera.typing.polars as pt
import polars as pl

class MetricsSchema(pa.DataFrameModel):
    item: str
    value: float

@env.task(report=True)
async def create_lazy() -> pt.LazyFrame[MetricsSchema]:
    return pl.LazyFrame({"item": ["x", "y"], "value": [3.0, 4.0]})

@env.task(report=True)
async def consume_lazy(
    lf: pt.LazyFrame[MetricsSchema],
) -> pt.DataFrame[MetricsSchema]:
    # collect() materializes the filtered frame before it is returned.
    return lf.filter(pl.col("value") > 0.0).collect()
```
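
The ordering that matters for lazy frames (build up a query, materialize it, then validate) can be sketched without Polars. `ToyLazyRows` is a hypothetical class that records filters and applies them only in `collect()`, and `validate_at_boundary` checks the materialized rows; neither reflects how pandera or Polars work internally.

```python
from typing import Callable, Iterable

Row = dict[str, object]


class ToyLazyRows:
    """Hypothetical lazy table: filters are recorded, then applied only by collect()."""

    def __init__(self, rows: Iterable[Row], predicates: tuple = ()):
        self._rows = list(rows)
        self._predicates = predicates

    def filter(self, predicate: Callable[[Row], bool]) -> "ToyLazyRows":
        # Return a new plan instead of touching the data.
        return ToyLazyRows(self._rows, self._predicates + (predicate,))

    def collect(self) -> list[Row]:
        rows = self._rows
        for predicate in self._predicates:
            rows = [row for row in rows if predicate(row)]
        return rows


def validate_at_boundary(lazy: ToyLazyRows, required: set[str]) -> list[Row]:
    """Materialize the plan, then require every row to carry the schema's columns."""
    rows = lazy.collect()
    for row in rows:
        missing = required - set(row)
        if missing:
            raise ValueError(f"missing columns: {sorted(missing)}")
    return rows


lf = ToyLazyRows([{"item": "x", "value": 3.0}, {"item": "y", "value": -1.0}])
print(validate_at_boundary(lf.filter(lambda r: r["value"] > 0.0), {"item", "value"}))
# [{'item': 'x', 'value': 3.0}]
```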

## Examples

### pandas

```python
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte",
#    "flyteplugins-pandera",
#    "pandera[pandas]",
# ]
# main = "main"
# ///

from __future__ import annotations

from typing import Annotated

import pandas as pd
import pandera.pandas as pa
import pandera.typing.pandas as pt
from flyteplugins.pandera import ValidationConfig

import flyte

img = flyte.Image.from_debian_base(python_version=(3, 12)).with_pip_packages(
    "flyteplugins-pandera", "pandera[pandas]"
)

env = flyte.TaskEnvironment(
    "pandera_pandas_schema",
    image=img,
    resources=flyte.Resources(cpu="1", memory="2Gi"),
)

class EmployeeSchema(pa.DataFrameModel):
    employee_id: int = pa.Field(ge=0)
    name: str

class EmployeeSchemaWithStatus(EmployeeSchema):
    status: str = pa.Field(isin=["active", "inactive"])

@env.task(report=True)
async def build_valid_employees() -> pt.DataFrame[EmployeeSchema]:
    return pd.DataFrame(
        {
            "employee_id": [1, 2, 3],
            "name": ["Ada", "Grace", "Barbara"],
        }
    )

@env.task(report=True)
async def pass_through(
    df: pt.DataFrame[EmployeeSchema],
) -> pt.DataFrame[EmployeeSchemaWithStatus]:
    return df.assign(status="active")

@env.task(report=True)
async def pass_through_with_error_warn(
    df: Annotated[
        pt.DataFrame[EmployeeSchema], ValidationConfig(on_error="warn")
    ],
) -> Annotated[
    pt.DataFrame[EmployeeSchemaWithStatus], ValidationConfig(on_error="warn")
]:
    del df["name"]
    return df

@env.task(report=True)
async def pass_through_with_error_raise(
    df: Annotated[
        pt.DataFrame[EmployeeSchema], ValidationConfig(on_error="warn")
    ],
) -> Annotated[
    pt.DataFrame[EmployeeSchemaWithStatus], ValidationConfig(on_error="raise")
]:
    del df["name"]
    return df

@env.task(report=True)
async def main() -> pt.DataFrame[EmployeeSchemaWithStatus]:
    df = await build_valid_employees()
    df2 = await pass_through(df)

    await pass_through_with_error_warn(df.drop(["employee_id"], axis="columns"))
    await pass_through_with_error_warn(df.assign(employee_id=-1))

    try:
        await pass_through_with_error_raise(df)
    except Exception as exc:
        print(exc)

    return df2

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
    run.wait()
    print("pandas pandera example OK:", run.outputs()[0])
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/pandera/pandas_schema.py*

### Polars

```python
# /// script
# requires-python = ">=3.12"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "flyteplugins-pandera",
#    "flyteplugins-polars",
#    "pandera[polars]",
# ]
# main = "main"
# ///

from __future__ import annotations

from typing import Annotated

import pandera.polars as pa
import pandera.typing.polars as pt
import polars as pl
from flyteplugins.pandera import ValidationConfig

import flyte

img = (
    flyte.Image.from_debian_base(python_version=(3, 12))
    .with_pip_packages("flyteplugins-pandera", "flyteplugins-polars", "pandera[polars]")
)

env = flyte.TaskEnvironment(
    "pandera_polars_schema",
    image=img,
    resources=flyte.Resources(cpu="1", memory="2Gi"),
)

class EmployeeSchema(pa.DataFrameModel):
    employee_id: int = pa.Field(ge=0)
    name: str

class EmployeeSchemaWithStatus(EmployeeSchema):
    status: str = pa.Field(isin=["active", "inactive"])

class MetricsSchema(pa.DataFrameModel):
    item: str
    value: float

@env.task(report=True)
async def build_valid_employees() -> pt.DataFrame[EmployeeSchema]:
    return pl.DataFrame(
        {
            "employee_id": [1, 2, 3],
            "name": ["Ada", "Grace", "Barbara"],
        }
    )

@env.task(report=True)
async def pass_through(
    df: pt.DataFrame[EmployeeSchema],
) -> pt.DataFrame[EmployeeSchemaWithStatus]:
    return df.with_columns(pl.lit("active").alias("status"))

@env.task(report=True)
async def pass_through_with_error_warn(
    df: Annotated[
        pt.DataFrame[EmployeeSchema], ValidationConfig(on_error="warn")
    ],
) -> Annotated[
    pt.DataFrame[EmployeeSchemaWithStatus], ValidationConfig(on_error="warn")
]:
    return df.drop("name")

@env.task(report=True)
async def pass_through_with_error_raise(
    df: Annotated[
        pt.DataFrame[EmployeeSchema], ValidationConfig(on_error="warn")
    ],
) -> Annotated[
    pt.DataFrame[EmployeeSchemaWithStatus], ValidationConfig(on_error="raise")
]:
    return df.drop("name")

@env.task(report=True)
async def metrics_eager() -> pt.DataFrame[MetricsSchema]:
    return pl.DataFrame({"item": ["a", "b"], "value": [1.0, 2.0]})

@env.task(report=True)
async def metrics_lazy() -> pt.LazyFrame[MetricsSchema]:
    return pl.LazyFrame({"item": ["x", "y"], "value": [3.0, 4.0]})

@env.task(report=True)
async def filter_metrics(
    lf: pt.LazyFrame[MetricsSchema],
) -> pt.DataFrame[MetricsSchema]:
    return lf.filter(pl.col("value") > 0.0).collect()

@env.task(report=True)
async def main() -> pt.DataFrame[EmployeeSchemaWithStatus]:
    df = await build_valid_employees()
    df2 = await pass_through(df)

    await pass_through_with_error_warn(df.drop("employee_id"))
    await pass_through_with_error_warn(
        df.with_columns(pl.lit(-1).alias("employee_id"))
    )

    try:
        await pass_through_with_error_raise(df)
    except Exception as exc:
        print(exc)

    _ = await metrics_eager()
    lazy = await metrics_lazy()
    _ = await filter_metrics(lazy)

    return df2

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
    run.wait()
    print("polars pandera example OK:", run.outputs()[0])
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/pandera/polars_schema.py*

### PySpark SQL

```python
# /// script
# requires-python = ">=3.10"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "flyteplugins-pandera",
#    "flyteplugins-spark",
#    "pandera[pyspark]",
# ]
# main = "main"
# ///

from __future__ import annotations

from typing import Annotated, cast

import pandera.typing.pyspark_sql as pt
import pyspark.sql.types as T
from flyteplugins.pandera import ValidationConfig
from flyteplugins.spark.task import Spark
from pandera.pyspark import DataFrameModel, Field
from pyspark.sql import SparkSession
from pyspark.sql import functions as F

import flyte

image = (
    flyte.Image.from_base("apache/spark-py:v3.4.0")
    .clone(name="pandera-pyspark-sql", python_version=(3, 10), extendable=True)
    .with_pip_packages(
        "flyteplugins-pandera",
        "flyteplugins-spark",
        "pandera[pyspark]",
    )
)

spark_conf = Spark(
    spark_conf={
        "spark.driver.memory": "1000M",
        "spark.executor.memory": "1000M",
        "spark.executor.cores": "1",
        "spark.executor.instances": "2",
        "spark.driver.cores": "1",
        "spark.kubernetes.file.upload.path": "/opt/spark/work-dir",
        "spark.jars": (
            "https://storage.googleapis.com/hadoop-lib/gcs/"
            "gcs-connector-hadoop3-latest.jar,"
            "https://repo1.maven.org/maven2/org/apache/hadoop/"
            "hadoop-aws/3.2.2/hadoop-aws-3.2.2.jar,"
            "https://repo1.maven.org/maven2/com/amazonaws/"
            "aws-java-sdk-bundle/1.12.262/aws-java-sdk-bundle-1.12.262.jar"
        ),
    },
)

env = flyte.TaskEnvironment(
    name="pandera_pyspark_sql_schema",
    plugin_config=spark_conf,
    image=image,
    resources=flyte.Resources(cpu="1", memory="2Gi"),
)

class EmployeeSchema(DataFrameModel):
    employee_id: int = Field(ge=0)
    name: str = Field()
    job_title: str = Field()

class EmployeeSchemaWithStatus(EmployeeSchema):
    status: str = Field(isin=["active", "inactive"])

@env.task(report=True)
async def build_valid_employees() -> pt.DataFrame[EmployeeSchema]:
    spark = cast(SparkSession, flyte.ctx().data["spark_session"])
    data = [
        (1, "Ada", "Engineer"),
        (2, "Grace", "Mathematician"),
        (3, "Barbara", "Computer scientist"),
    ]
    schema = T.StructType(
        [
            T.StructField("employee_id", T.IntegerType(), False),
            T.StructField("name", T.StringType(), False),
            T.StructField("job_title", T.StringType(), False),
        ]
    )
    return spark.createDataFrame(data, schema=schema)

@env.task(report=True)
async def pass_through(
    df: pt.DataFrame[EmployeeSchema],
) -> pt.DataFrame[EmployeeSchemaWithStatus]:
    return df.withColumn("status", F.lit("active"))

@env.task(report=True)
async def pass_through_with_error_warn(
    df: Annotated[
        pt.DataFrame[EmployeeSchema], ValidationConfig(on_error="warn")
    ],
) -> Annotated[
    pt.DataFrame[EmployeeSchemaWithStatus], ValidationConfig(on_error="warn")
]:
    return df.drop("name")

@env.task(report=True)
async def pass_through_with_error_raise(
    df: Annotated[
        pt.DataFrame[EmployeeSchema], ValidationConfig(on_error="warn")
    ],
) -> Annotated[
    pt.DataFrame[EmployeeSchemaWithStatus], ValidationConfig(on_error="raise")
]:
    return df.drop("name")

@env.task(report=True)
async def main() -> pt.DataFrame[EmployeeSchemaWithStatus]:
    df = await build_valid_employees()
    df2 = await pass_through(df)

    await pass_through_with_error_warn(df.drop("employee_id"))
    await pass_through_with_error_warn(df.withColumn("employee_id", F.lit(-1)))

    try:
        await pass_through_with_error_raise(df)
    except Exception as exc:
        print(exc)

    return df2

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(main)
    print(run.url)
    run.wait()
    print("pyspark_sql pandera example OK:", run.outputs()[0])
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/pandera/pyspark_sql_schema.py*
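
In the example, `on_error="raise"` turns a schema violation into an exception (hence the `try`/`except` around `pass_through_with_error_raise`), while `on_error="warn"` lets the task continue. To make that difference concrete without Spark, the sketch below checks plain Python rows against the same rules as `EmployeeSchemaWithStatus` (all columns present, `employee_id >= 0`, `status` in `{"active", "inactive"}`). `validate_employees` is a hypothetical stand-in for illustration only; it is not how pandera validates PySpark DataFrames.

```python
import warnings
from typing import Any, Dict, List

REQUIRED_COLUMNS = ("employee_id", "name", "job_title", "status")
ALLOWED_STATUS = {"active", "inactive"}


def validate_employees(rows: List[Dict[str, Any]], on_error: str = "raise") -> List[str]:
    """Check rows against EmployeeSchemaWithStatus-like rules (hypothetical stand-in)."""
    if on_error not in ("raise", "warn"):
        raise ValueError(f"unsupported on_error: {on_error!r}")
    errors: List[str] = []
    for i, row in enumerate(rows):
        for column in REQUIRED_COLUMNS:
            if column not in row:
                errors.append(f"row {i}: missing column {column!r}")
        if "employee_id" in row and row["employee_id"] < 0:
            errors.append(f"row {i}: employee_id must be >= 0")
        if "status" in row and row["status"] not in ALLOWED_STATUS:
            errors.append(f"row {i}: status {row['status']!r} not in {sorted(ALLOWED_STATUS)}")
    if errors:
        message = "; ".join(errors)
        if on_error == "raise":
            raise ValueError(message)  # stop the task
        warnings.warn(message)  # report the problem and continue
    return errors
```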

=== PAGE: https://www.union.ai/docs/v2/union/integrations/pytorch ===

# PyTorch

The PyTorch plugin lets you run distributed [PyTorch](https://pytorch.org/) training jobs natively on Kubernetes. It uses the [Kubeflow Training Operator](https://github.com/kubeflow/training-operator) to manage multi-node training with PyTorch's elastic launch (`torchrun`).

## When to use this plugin

- Single-node or multi-node distributed training with `DistributedDataParallel` (DDP)
- Elastic training that can scale up and down during execution
- Any workload that uses `torch.distributed` for data-parallel or model-parallel training

## Installation

```bash
pip install flyteplugins-pytorch
```

## Configuration

Create an `Elastic` configuration and pass it as `plugin_config` to a `TaskEnvironment`:

```python
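import flyte
from flyteplugins.pytorch import Elastic

# Image with the PyTorch plugin installed (as in the full example below)
image = flyte.Image.from_debian_base(name="torch").with_pip_packages("flyteplugins-pytorch", pre=True)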

torch_env = flyte.TaskEnvironment(
    name="torch_env",
    resources=flyte.Resources(cpu=(1, 2), memory=("1Gi", "2Gi")),
    plugin_config=Elastic(
        nnodes=2,
        nproc_per_node=1,
    ),
    image=image,
)
```

### `Elastic` parameters

| Parameter | Type | Description |
|-----------|------|-------------|
| `nnodes` | `int` or `str` | **Required.** Number of nodes. Use an int for a fixed count or a range string (e.g., `"2:4"`) for elastic training |
| `nproc_per_node` | `int` | **Required.** Number of processes (workers) per node |
| `rdzv_backend` | `str` | Rendezvous backend: `"c10d"` (default), `"etcd"`, or `"etcd-v2"` |
| `max_restarts` | `int` | Maximum worker group restarts (default: `3`) |
| `monitor_interval` | `int` | Agent health check interval in seconds (default: `3`) |
| `run_policy` | `RunPolicy` | Job run policy (cleanup, TTL, deadlines, retries) |
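
`nnodes` follows the same convention as `torchrun --nnodes`: a single number fixes the node count, while `"MIN:MAX"` lets the job run with anywhere from MIN to MAX nodes. The hypothetical helper below only illustrates how such a value decomposes; it is not part of `flyteplugins-pytorch`.

```python
from typing import Tuple, Union


def parse_nnodes(nnodes: Union[int, str]) -> Tuple[int, int]:
    """Return (min_nodes, max_nodes) for a fixed count or a "MIN:MAX" range string."""
    if isinstance(nnodes, int):
        min_nodes = max_nodes = nnodes
    elif ":" in nnodes:
        low, high = nnodes.split(":", 1)
        min_nodes, max_nodes = int(low), int(high)
    else:
        min_nodes = max_nodes = int(nnodes)  # a plain numeric string, e.g. "2"
    if min_nodes < 1 or max_nodes < min_nodes:
        raise ValueError(f"invalid nnodes: {nnodes!r}")
    return min_nodes, max_nodes


print(parse_nnodes(2))      # (2, 2): fixed size
print(parse_nnodes("2:4"))  # (2, 4): elastic, between 2 and 4 nodes
```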

### `RunPolicy` parameters

| Parameter | Type | Description |
|-----------|------|-------------|
| `clean_pod_policy` | `str` | Pod cleanup policy: `"None"`, `"all"`, or `"Running"` |
| `ttl_seconds_after_finished` | `int` | Seconds to keep pods after job completion |
| `active_deadline_seconds` | `int` | Maximum time the job can run (seconds) |
| `backoff_limit` | `int` | Number of retries before marking the job as failed |

### NCCL tuning parameters

By default, the plugin tunes NCCL timeouts so that failed or hung workers are detected sooner than with PyTorch's default of 1800 seconds:

| Parameter | Type | Default | Description |
|-----------|------|---------|-------------|
| `nccl_heartbeat_timeout_sec` | `int` | `300` | NCCL heartbeat timeout (seconds) |
| `nccl_async_error_handling` | `bool` | `False` | Enable async NCCL error handling |
| `nccl_collective_timeout_sec` | `Optional[int]` | `None` | Timeout for NCCL collective operations (seconds) |
| `nccl_enable_monitoring` | `bool` | `True` | Enable NCCL monitoring |

### Writing a distributed training task

Tasks using this plugin do not need to be `async`. Initialize the process group and use `DistributedDataParallel` as you normally would with `torchrun`:

```python
import torch
import torch.distributed
from torch.nn.parallel import DistributedDataParallel as DDP

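@torch_env.task
def train(epochs: int) -> float:
    torch.distributed.init_process_group("gloo")  # "nccl" is the usual backend for GPU training
    model = DDP(MyModel())  # MyModel: your own torch.nn.Module subclass
    final_loss = 0.0
    # ... training loop that runs for `epochs` and updates final_loss ...
    return final_loss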
```
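
Because the task body is launched by `torchrun`'s elastic agent, each worker learns its rank and the rendezvous endpoint from environment variables, which `init_process_group` reads under its default `env://` initialization. The standard-library sketch below gathers them, for instance to log which worker is which while debugging. `read_dist_env` and `DistEnv` are hypothetical names, not part of the plugin.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional

# Variables torchrun exports to every worker process; init_process_group()
# reads RANK, WORLD_SIZE, MASTER_ADDR and MASTER_PORT with init_method="env://".
REQUIRED_VARS = ("RANK", "WORLD_SIZE", "MASTER_ADDR", "MASTER_PORT")


@dataclass(frozen=True)
class DistEnv:
    rank: int
    local_rank: int
    world_size: int
    master_addr: str
    master_port: int


def read_dist_env(environ: Optional[Mapping[str, str]] = None) -> DistEnv:
    """Collect the rendezvous settings of the current worker (hypothetical helper)."""
    env = os.environ if environ is None else environ
    missing = [name for name in REQUIRED_VARS if name not in env]
    if missing:
        raise RuntimeError(f"not running under torchrun? missing: {', '.join(missing)}")
    return DistEnv(
        rank=int(env["RANK"]),
        local_rank=int(env.get("LOCAL_RANK", "0")),
        world_size=int(env["WORLD_SIZE"]),
        master_addr=env["MASTER_ADDR"],
        master_port=int(env["MASTER_PORT"]),
    )
```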

> [!NOTE]
> When `nnodes=1`, the task runs as a regular Python task (no Kubernetes training job is created). Set `nnodes >= 2` for multi-node distributed training.

## Example

```python
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "flyteplugins-pytorch",
#    "torch"
# ]
# main = "torch_distributed_train"
# params = "3"
# ///

import typing

import torch
import torch.distributed
import torch.nn as nn
import torch.optim as optim
from flyteplugins.pytorch.task import Elastic
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler, TensorDataset

import flyte

image = flyte.Image.from_debian_base(name="torch").with_pip_packages("flyteplugins-pytorch", pre=True)

torch_env = flyte.TaskEnvironment(
    name="torch_env",
    resources=flyte.Resources(cpu=(1, 2), memory=("1Gi", "2Gi")),
    plugin_config=Elastic(
        nproc_per_node=1,
        # if you want to do local testing set nnodes=1
        nnodes=2,
    ),
    image=image,
)

class LinearRegressionModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(1, 1)

    def forward(self, x):
        return self.linear(x)

def prepare_dataloader(rank: int, world_size: int, batch_size: int = 2) -> DataLoader:
    """
    Prepare a DataLoader with a DistributedSampler so each rank
    gets a shard of the dataset.
    """
    # Dummy dataset
    x_train = torch.tensor([[1.0], [2.0], [3.0], [4.0]])
    y_train = torch.tensor([[3.0], [5.0], [7.0], [9.0]])
    dataset = TensorDataset(x_train, y_train)

    # Distributed-aware sampler
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)

    return DataLoader(dataset, batch_size=batch_size, sampler=sampler)

def train_loop(epochs: int = 3) -> float:
    """
    A simple training loop for linear regression.
    """
    torch.distributed.init_process_group("gloo")
    model = DDP(LinearRegressionModel())

    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()

    dataloader = prepare_dataloader(
        rank=rank,
        world_size=world_size,
        batch_size=64,
    )

    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=0.01)

    final_loss = 0.0

    for _ in range(epochs):
        for x, y in dataloader:
            outputs = model(x)
            loss = criterion(outputs, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            final_loss = loss.item()
        if torch.distributed.get_rank() == 0:
            print(f"Loss: {final_loss}")

    return final_loss

@torch_env.task
def torch_distributed_train(epochs: int) -> typing.Optional[float]:
    """
    A task that runs a simple distributed training job using PyTorch's DistributedDataParallel.
    """
    print("starting launcher")
    loss = train_loop(epochs=epochs)
    print("Training complete")
    return loss

if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(torch_distributed_train, epochs=3)
    print(r.name)
    print(r.url)
    r.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/pytorch/pytorch_example.py*
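
In the example above, `DistributedSampler` gives each rank its own slice of the four-sample dataset, so with `nnodes=2` and `nproc_per_node=1` each worker sees two samples per epoch (and `batch_size=64` yields a single batch). The sketch below reproduces the sampler's index split for the unshuffled case in plain Python, mirroring PyTorch's pad-then-interleave behavior. With `shuffle=True`, the indices are first permuted using a seed plus the epoch number, which is why training loops usually call `sampler.set_epoch(epoch)` at the start of each epoch. `shard_indices` is a hypothetical name.

```python
import math
from typing import List


def shard_indices(dataset_len: int, num_replicas: int, rank: int) -> List[int]:
    """Indices one rank receives from an unshuffled DistributedSampler (drop_last=False)."""
    if dataset_len <= 0 or not 0 <= rank < num_replicas:
        raise ValueError("need a non-empty dataset and 0 <= rank < num_replicas")
    num_samples = math.ceil(dataset_len / num_replicas)
    total_size = num_samples * num_replicas
    indices = list(range(dataset_len))
    # Pad by repeating indices from the start so every rank gets num_samples items.
    padding = total_size - dataset_len
    indices += (indices * math.ceil(padding / dataset_len))[:padding]
    # Interleaved split: rank r takes positions r, r + num_replicas, ...
    return indices[rank:total_size:num_replicas]


print(shard_indices(4, 2, 0))  # [0, 2]
print(shard_indices(4, 2, 1))  # [1, 3]
```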

## API reference

See the [PyTorch API reference](https://www.union.ai/docs/v2/union/api-reference/integrations/pytorch/_index) for full details.

=== PAGE: https://www.union.ai/docs/v2/union/integrations/ray ===

# Ray

The Ray plugin lets you run [Ray](https://www.ray.io/) jobs natively on Kubernetes. Flyte provisions a transient Ray cluster for each task execution using [KubeRay](https://github.com/ray-project/kuberay) and tears it down on completion.

## When to use this plugin

- Distributed Python workloads (parallel computation, data processing)
- ML training with Ray Train or hyperparameter tuning with Ray Tune
- Ray Serve inference workloads
- Any workload that benefits from Ray's actor model or task parallelism

## Installation

```bash
pip install flyteplugins-ray
```

Your task image must also include a compatible version of Ray:

```python
import flyte

image = (
    flyte.Image.from_debian_base(name="ray")
    .with_pip_packages("ray[default]==2.46.0", "flyteplugins-ray")
)
```

## Configuration

Create a `RayJobConfig` and pass it as `plugin_config` to a `TaskEnvironment`:

```python
from flyteplugins.ray import HeadNodeConfig, RayJobConfig, WorkerNodeConfig

ray_config = RayJobConfig(
    head_node_config=HeadNodeConfig(ray_start_params={"log-color": "True"}),
    worker_node_config=[WorkerNodeConfig(group_name="ray-group", replicas=2)],
    runtime_env={"pip": ["numpy", "pandas"]},
    enable_autoscaling=False,
    shutdown_after_job_finishes=True,
    ttl_seconds_after_finished=300,
)

ray_env = flyte.TaskEnvironment(
    name="ray_env",
    plugin_config=ray_config,
    image=image,
)
```

### `RayJobConfig` parameters

| Parameter | Type | Description |
|-----------|------|-------------|
| `worker_node_config` | `List[WorkerNodeConfig]` | **Required.** List of worker group configurations |
| `head_node_config` | `HeadNodeConfig` | Head node configuration (optional) |
| `enable_autoscaling` | `bool` | Enable Ray autoscaler (default: `False`) |
| `runtime_env` | `dict` | Ray runtime environment (pip packages, env vars, etc.) |
| `address` | `str` | Connect to an existing Ray cluster instead of provisioning one |
| `shutdown_after_job_finishes` | `bool` | Shut down the cluster after the job completes (default: `False`) |
| `ttl_seconds_after_finished` | `int` | Seconds to keep the cluster after completion before cleanup |

### `WorkerNodeConfig` parameters

| Parameter | Type | Description |
|-----------|------|-------------|
| `group_name` | `str` | **Required.** Name of this worker group |
| `replicas` | `int` | **Required.** Number of worker replicas |
| `min_replicas` | `int` | Minimum replicas (for autoscaling) |
| `max_replicas` | `int` | Maximum replicas (for autoscaling) |
| `ray_start_params` | `Dict[str, str]` | Ray start parameters for workers |
| `requests` | `Resources` | Resource requests per worker |
| `limits` | `Resources` | Resource limits per worker |
| `pod_template` | `PodTemplate` | Full pod template (mutually exclusive with `requests`/`limits`) |
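
The autoscaling bounds and the `pod_template` exclusivity rule are easy to get wrong. The sketch below mirrors the table's fields in a hypothetical local dataclass (`WorkerGroupSpec`, not the plugin's `WorkerNodeConfig`) and checks both rules up front. It assumes that, as with KubeRay worker groups, `replicas` should lie between `min_replicas` and `max_replicas`.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class WorkerGroupSpec:
    """Hypothetical mirror of the WorkerNodeConfig fields, for local validation only."""

    group_name: str
    replicas: int
    min_replicas: Optional[int] = None
    max_replicas: Optional[int] = None
    ray_start_params: Dict[str, str] = field(default_factory=dict)
    requests: Optional[Dict[str, Any]] = None
    limits: Optional[Dict[str, Any]] = None
    pod_template: Optional[Dict[str, Any]] = None

    def __post_init__(self) -> None:
        # pod_template replaces requests/limits entirely.
        if self.pod_template is not None and (self.requests or self.limits):
            raise ValueError("pod_template is mutually exclusive with requests/limits")
        # Unset bounds default to the fixed replica count.
        low = self.replicas if self.min_replicas is None else self.min_replicas
        high = self.replicas if self.max_replicas is None else self.max_replicas
        if not 0 <= low <= self.replicas <= high:
            raise ValueError(
                f"{self.group_name}: need 0 <= min_replicas <= replicas <= max_replicas, "
                f"got {low}, {self.replicas}, {high}"
            )


WorkerGroupSpec(group_name="ray-group", replicas=2, min_replicas=1, max_replicas=4)
```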

### `HeadNodeConfig` parameters

| Parameter | Type | Description |
|-----------|------|-------------|
| `ray_start_params` | `Dict[str, str]` | Ray start parameters for the head node |
| `requests` | `Resources` | Resource requests for the head node |
| `limits` | `Resources` | Resource limits for the head node |
| `pod_template` | `PodTemplate` | Full pod template (mutually exclusive with `requests`/`limits`) |

### Connecting to an existing cluster

To connect to an existing Ray cluster instead of provisioning a new one, set the `address` parameter:

```python
ray_config = RayJobConfig(
    worker_node_config=[WorkerNodeConfig(group_name="ray-group", replicas=2)],
    address="ray://existing-cluster:10001",
)
```
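
The `ray://` prefix tells `ray.init` to connect through Ray Client, whose server listens on port 10001 by default. When the host is only known at runtime (the second example on this page passes a head node IP between runs), a small hypothetical helper keeps the address format in one place. It accepts bare host names and IPv4 addresses only.

```python
def ray_client_address(host: str, port: int = 10001) -> str:
    """Build a Ray Client address; 10001 is Ray Client's default server port."""
    host = host.strip()
    # Reject empty hosts, full URLs, and host:port strings to avoid doubled ports.
    if not host or "://" in host or ":" in host:
        raise ValueError(f"expected a bare host name or IPv4 address, got {host!r}")
    if not 0 < port < 65536:
        raise ValueError(f"invalid port: {port}")
    return f"ray://{host}:{port}"


print(ray_client_address("existing-cluster"))  # ray://existing-cluster:10001
```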

## Examples

The following example shows how to configure Ray in a `TaskEnvironment`. Flyte automatically provisions a Ray cluster for each task using this configuration:

```python
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "flyteplugins-ray",
#    "ray[default]==2.46.0"
# ]
# main = "hello_ray_nested"
# params = "3"
# ///

import asyncio
import typing

import ray
from flyteplugins.ray.task import HeadNodeConfig, RayJobConfig, WorkerNodeConfig

import flyte

@ray.remote
def f(x):
    return x * x

ray_config = RayJobConfig(
    head_node_config=HeadNodeConfig(ray_start_params={"log-color": "True"}),
    worker_node_config=[WorkerNodeConfig(group_name="ray-group", replicas=2)],
    runtime_env={"pip": ["numpy", "pandas"]},
    enable_autoscaling=False,
    shutdown_after_job_finishes=True,
    ttl_seconds_after_finished=300,
)

image = (
    flyte.Image.from_debian_base(name="ray")
    .with_apt_packages("wget")
    .with_pip_packages("ray[default]==2.46.0", "flyteplugins-ray", "pip", "mypy")
)

task_env = flyte.TaskEnvironment(
    name="hello_ray", resources=flyte.Resources(cpu=(1, 2), memory=("400Mi", "1000Mi")), image=image
)
ray_env = flyte.TaskEnvironment(
    name="ray_env",
    plugin_config=ray_config,
    image=image,
    resources=flyte.Resources(cpu=(3, 4), memory=("3000Mi", "5000Mi")),
    depends_on=[task_env],
)

@task_env.task()
async def hello_ray():
    await asyncio.sleep(20)
    print("Hello from the Ray task!")

@ray_env.task
async def hello_ray_nested(n: int = 3) -> typing.List[int]:
    print("running ray task")
    t = asyncio.create_task(hello_ray())
    futures = [f.remote(i) for i in range(n)]
    res = ray.get(futures)
    await t
    return res

if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(hello_ray_nested)
    print(r.name)
    print(r.url)
    r.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/ray/ray_example.py*

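The next example creates an ephemeral Ray cluster in one run and connects to it from a separate run. `ttl_seconds_after_finished=3600` keeps the cluster alive for an hour after `create_ray_cluster` finishes, and the second run's `hello_ray` task connects to it using the head node IP that the first run returns: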

```python
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "flyteplugins-ray",
#    "ray[default]==2.46.0"
# ]
# main = "create_ray_cluster"
# params = ""
# ///

import os
import typing

import ray
from flyteplugins.ray.task import HeadNodeConfig, RayJobConfig, WorkerNodeConfig

import flyte

@ray.remote
def f(x):
    return x * x

ray_config = RayJobConfig(
    head_node_config=HeadNodeConfig(ray_start_params={"log-color": "True"}),
    worker_node_config=[WorkerNodeConfig(group_name="ray-group", replicas=2)],
    enable_autoscaling=False,
    shutdown_after_job_finishes=True,
    ttl_seconds_after_finished=3600,
)

image = (
    flyte.Image.from_debian_base(name="ray")
    .with_apt_packages("wget")
    .with_pip_packages("ray[default]==2.46.0", "flyteplugins-ray")
)

task_env = flyte.TaskEnvironment(
    name="ray_client", resources=flyte.Resources(cpu=(1, 2), memory=("400Mi", "1000Mi")), image=image
)
ray_env = flyte.TaskEnvironment(
    name="ray_cluster",
    plugin_config=ray_config,
    image=image,
    resources=flyte.Resources(cpu=(2, 4), memory=("2000Mi", "4000Mi")),
    depends_on=[task_env],
)

@task_env.task()
async def hello_ray(cluster_ip: str) -> typing.List[int]:
    """
    Run a simple Ray task that connects to an existing Ray cluster.
    """
    ray.init(address=f"ray://{cluster_ip}:10001")
    futures = [f.remote(i) for i in range(5)]
    res = ray.get(futures)
    return res

@ray_env.task
async def create_ray_cluster() -> str:
    """
    Create a Ray cluster and return the head node IP address.
    """
    print("creating ray cluster")
    cluster_ip = os.getenv("MY_POD_IP")
    if cluster_ip is None:
        raise ValueError("MY_POD_IP environment variable is not set")
    return f"{cluster_ip}"

if __name__ == "__main__":
    flyte.init_from_config()
    run = flyte.run(create_ray_cluster)
    run.wait()
    print("run url:", run.url)
    print("cluster created, running ray task")
    print("ray address:", run.outputs()[0])
    run = flyte.run(hello_ray, cluster_ip=run.outputs()[0])
    print("run url:", run.url)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/ray/ray_existing_example.py*

## API reference

See the [Ray API reference](https://www.union.ai/docs/v2/union/api-reference/integrations/ray/_index) for full details.

=== PAGE: https://www.union.ai/docs/v2/union/integrations/snowflake ===

# Snowflake

The Snowflake connector lets you run SQL queries against [Snowflake](https://www.snowflake.com/) directly from Flyte tasks. Queries are submitted asynchronously and polled for completion, so they don't block a worker while waiting for results.

The connector supports:

- Parameterized SQL queries with typed inputs
- Key-pair, password-based, and other authentication methods (such as OAuth)
- Query results returned as DataFrames
- Automatic links to the Snowflake query dashboard in the Flyte UI
- Query cancellation on task abort

## Installation

```bash
pip install flyteplugins-snowflake
```

This installs the Snowflake Python connector and the `cryptography` library for key-pair authentication.

## Quick start

Here's a minimal example that runs a SQL query on Snowflake:

```python {hl_lines=[2, 4, 12]}
from flyte.io import DataFrame
from flyteplugins.connectors.snowflake import Snowflake, SnowflakeConfig

config = SnowflakeConfig(
    account="myorg-myaccount",
    user="flyte_user",
    database="ANALYTICS",
    schema="PUBLIC",
    warehouse="COMPUTE_WH",
)

count_users = Snowflake(
    name="count_users",
    query_template="SELECT COUNT(*) FROM users",
    plugin_config=config,
    output_dataframe_type=DataFrame,
)
```

This defines a task called `count_users` that runs `SELECT COUNT(*) FROM users` on the configured Snowflake instance. When executed, the connector:

1. Connects to Snowflake using the provided configuration
2. Submits the query asynchronously
3. Polls until the query completes or fails
4. Provides a link to the query in the Snowflake dashboard
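
The submit-then-poll flow in steps 2 and 3, plus cancellation on abort, can be sketched with `asyncio`. The client below is a hypothetical in-memory fake so the sketch runs anywhere. The real connector talks to Snowflake (the Snowflake Python connector offers `cursor.execute_async()` and `connection.get_query_status()` for this pattern), and its exact polling logic is not shown on this page.

```python
import asyncio
from enum import Enum


class QueryStatus(Enum):
    RUNNING = "RUNNING"
    SUCCESS = "SUCCESS"
    FAILED = "FAILED"


class FakeWarehouse:
    """Hypothetical stand-in for a warehouse client (not the Snowflake connector API)."""

    def __init__(self, polls_until_done: int, fail: bool = False) -> None:
        self._remaining = polls_until_done
        self._fail = fail
        self.cancelled = False

    def submit(self, sql: str) -> str:
        return "query-0001"  # returns immediately with a query ID

    def status(self, query_id: str) -> QueryStatus:
        if self._remaining > 0:
            self._remaining -= 1
            return QueryStatus.RUNNING
        return QueryStatus.FAILED if self._fail else QueryStatus.SUCCESS

    def cancel(self, query_id: str) -> None:
        self.cancelled = True


async def run_query(
    client: FakeWarehouse, sql: str, poll_interval: float = 0.01, timeout: float = 5.0
) -> str:
    query_id = client.submit(sql)  # submit without waiting for the result
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout
    try:
        while True:
            status = client.status(query_id)  # poll
            if status is QueryStatus.SUCCESS:
                return query_id
            if status is QueryStatus.FAILED:
                raise RuntimeError(f"query {query_id} failed")
            if loop.time() >= deadline:
                raise TimeoutError(f"query {query_id} still running after {timeout}s")
            await asyncio.sleep(poll_interval)  # yield instead of blocking a worker
    except (asyncio.CancelledError, TimeoutError):
        client.cancel(query_id)  # on abort or timeout, cancel the remote query too
        raise
```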

![Snowflake Link](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/integrations/snowflake/ui.png)

To run the task, create a `TaskEnvironment` from it and execute it locally or remotely:

```python {hl_lines=3}
import flyte

snowflake_env = flyte.TaskEnvironment.from_task("snowflake_env", count_users)

if __name__ == "__main__":
    flyte.init_from_config()

    # Run locally (connector runs in-process, requires credentials and packages locally)
    run = flyte.with_runcontext(mode="local").run(count_users)

    # Run remotely (connector runs on the control plane)
    run = flyte.with_runcontext(mode="remote").run(count_users)

    print(run.url)
```

> [!NOTE]
> The `TaskEnvironment` created by `from_task` does not need an image or pip packages. Snowflake tasks are connector tasks, which means the query executes on the connector service, not in your task container. In `local` mode, the connector runs in-process and requires `flyteplugins-snowflake` and credentials to be available on your machine. In `remote` mode, the connector runs on the control plane.

## Configuration

The `SnowflakeConfig` dataclass defines the connection settings for your Snowflake instance.

### Required fields

| Field       | Type  | Description                                             |
| ----------- | ----- | ------------------------------------------------------- |
| `account`   | `str` | Snowflake account identifier (e.g. `"myorg-myaccount"`) |
| `database`  | `str` | Target database name                                    |
| `schema`    | `str` | Target schema name (e.g. `"PUBLIC"`)                    |
| `warehouse` | `str` | Compute warehouse to use for query execution            |
| `user`      | `str` | Snowflake username                                      |

### Additional connection parameters

Use `connection_kwargs` to pass any additional parameters supported by the [Snowflake Python connector](https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-api). This is a dictionary that gets forwarded directly to `snowflake.connector.connect()`.

Common options include:

| Parameter       | Type  | Description                                                                |
| --------------- | ----- | -------------------------------------------------------------------------- |
| `role`          | `str` | Snowflake role to use for the session                                      |
| `authenticator` | `str` | Authentication method (e.g. `"snowflake"`, `"externalbrowser"`, `"oauth"`) |
| `token`         | `str` | OAuth token when using `authenticator="oauth"`                             |
| `login_timeout` | `int` | Timeout in seconds for the login request                                   |

Example with a role:

```python {hl_lines=8}
config = SnowflakeConfig(
    account="myorg-myaccount",
    user="flyte_user",
    database="ANALYTICS",
    schema="PUBLIC",
    warehouse="COMPUTE_WH",
    connection_kwargs={
        "role": "DATA_ANALYST",
    },
)
```

## Authentication

The connector supports two authentication approaches: key-pair authentication, configured on the `Snowflake` task, and password, OAuth, or other methods passed through `connection_kwargs`.

### Key-pair authentication

Key-pair authentication is the recommended approach for automated workloads. Pass the names of the Flyte secrets containing the private key and optional passphrase:

```python {hl_lines=[5, 6]}
query = Snowflake(
    name="secure_query",
    query_template="SELECT * FROM sensitive_data",
    plugin_config=config,
    snowflake_private_key="my-snowflake-private-key",
    snowflake_private_key_passphrase="my-snowflake-pk-passphrase",
)
```

The `snowflake_private_key` parameter is the name of the secret (or secret key) that contains your PEM-encoded private key. The `snowflake_private_key_passphrase` parameter is the name of the secret (or secret key) that contains the passphrase, if your key is encrypted. If your key is not encrypted, omit the passphrase parameter.

The connector decodes the PEM key and converts it to DER format for Snowflake authentication.
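
For an unencrypted PKCS#8 key, PEM is just base64-encoded DER wrapped in `BEGIN`/`END` lines, so the conversion step can be illustrated with the standard library alone. This is a hypothetical sketch, not the connector's code: encrypted keys (`BEGIN ENCRYPTED PRIVATE KEY`) must first be decrypted with the passphrase using a cryptography library, and the sketch deliberately rejects them.

```python
import base64
import re

_PEM_RE = re.compile(
    r"-----BEGIN (?P<label>[A-Z ]+)-----\s*(?P<body>[A-Za-z0-9+/=\s]+?)\s*-----END (?P=label)-----"
)


def pem_to_der(pem: str) -> bytes:
    """Decode an unencrypted PKCS#8 PEM private key into DER bytes (hypothetical helper)."""
    match = _PEM_RE.search(pem)
    if match is None:
        raise ValueError("no PEM block found")
    if match.group("label") != "PRIVATE KEY":
        # e.g. "ENCRYPTED PRIVATE KEY" needs the passphrase and a crypto library
        raise ValueError(f"unsupported PEM block: {match.group('label')}")
    body = "".join(match.group("body").split())  # drop line breaks inside the base64
    return base64.b64decode(body, validate=True)
```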

> [!NOTE]
> If your credentials are stored in a secret group, you can pass `secret_group` to the `Snowflake` task. The plugin expects `snowflake_private_key` and
> `snowflake_private_key_passphrase` to be keys within the same secret group.

### Password authentication

Send the password via `connection_kwargs`:

```python {hl_lines=8}
config = SnowflakeConfig(
    account="myorg-myaccount",
    user="flyte_user",
    database="ANALYTICS",
    schema="PUBLIC",
    warehouse="COMPUTE_WH",
    connection_kwargs={
        "password": "my-password",
    },
)
```

### OAuth authentication

For OAuth-based authentication, specify the authenticator and token:

```python {hl_lines=["8-9"]}
config = SnowflakeConfig(
    account="myorg-myaccount",
    user="flyte_user",
    database="ANALYTICS",
    schema="PUBLIC",
    warehouse="COMPUTE_WH",
    connection_kwargs={
        "authenticator": "oauth",
        "token": "<oauth-token>",
    },
)
```

## Query templating

Use the `inputs` parameter to define typed inputs for your query. Input values are bound using the `%(param)s` syntax supported by the [Snowflake Python connector](https://docs.snowflake.com/en/developer-guide/python-connector/python-connector-api), which handles type conversion and escaping automatically.

### Supported input types

The `inputs` dictionary maps parameter names to Python types. Supported scalar types include `str`, `int`, `float`, and `bool`.

To insert multiple rows in a single query, declare list types for the inputs (for example, `list[str]`) and set `batch=True` on the `Snowflake` task. With batching enabled, the list inputs are expanded into a single multi-row query, so you don't have to write one statement per row.

### Batched `INSERT` with list inputs

When `batch=True` is enabled, a parameterized `INSERT` query with list inputs is automatically expanded into a multi-row `VALUES` statement.

Example:

```python
query = "INSERT INTO t (a, b) VALUES (%(a)s, %(b)s)"
inputs = {"a": [1, 2], "b": ["x", "y"]}
```

This is expanded into:

```sql
INSERT INTO t (a, b)
VALUES (%(a_0)s, %(b_0)s), (%(a_1)s, %(b_1)s)
```

with the following flattened parameters:

```python
flat_params = {
    "a_0": 1,
    "b_0": "x",
    "a_1": 2,
    "b_1": "y",
}
```

#### Constraints

- The query must contain exactly one `VALUES (...)` clause.
- All list inputs must have the same non-zero length.
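
The expansion and both constraints can be sketched in plain Python. The helper name and regular expressions are hypothetical and simplified (they assume `%(name)s` placeholders and no other parentheses inside the `VALUES` row), so treat this as an illustration of the transformation rather than the plugin's implementation. Run on the example above, it produces the same SQL (on one line) and the same flattened parameters.

```python
import re
from typing import Any, Dict, List, Tuple

_VALUES_RE = re.compile(r"\bVALUES\s*\((?P<row>(?:[^()]|\(\w+\))*)\)", re.IGNORECASE)
_PLACEHOLDER_RE = re.compile(r"%\((\w+)\)s")


def expand_batch_insert(query: str, inputs: Dict[str, List[Any]]) -> Tuple[str, Dict[str, Any]]:
    """Expand one VALUES row with list inputs into a multi-row INSERT (hypothetical helper)."""
    matches = list(_VALUES_RE.finditer(query))
    if len(matches) != 1:
        raise ValueError("query must contain exactly one VALUES (...) clause")
    lengths = {len(values) for values in inputs.values()}
    if len(lengths) != 1 or 0 in lengths:
        raise ValueError("all list inputs must have the same non-zero length")
    (n_rows,) = lengths
    match = matches[0]
    rows: List[str] = []
    flat_params: Dict[str, Any] = {}
    for i in range(n_rows):
        # Rename each %(name)s in the row template to %(name_i)s for row i.
        row = _PLACEHOLDER_RE.sub(lambda m, i=i: f"%({m.group(1)}_{i})s", match.group("row"))
        rows.append(f"({row})")
        for name, values in inputs.items():
            flat_params[f"{name}_{i}"] = values[i]
    expanded = query[: match.start()] + "VALUES " + ", ".join(rows) + query[match.end():]
    return expanded, flat_params
```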

### Parameterized `SELECT`

```python {hl_lines=[5, 7]}
from flyte.io import DataFrame

events_by_date = Snowflake(
    name="events_by_date",
    query_template="SELECT * FROM events WHERE event_date = %(event_date)s",
    plugin_config=config,
    inputs={"event_date": str},
    output_dataframe_type=DataFrame,
)
```

You can call the task with the required inputs:

```python {hl_lines=3}
@env.task
async def fetch_events() -> DataFrame:
    return await events_by_date(event_date="2024-01-15")
```

### Multiple inputs

You can define multiple input parameters of different types:

```python {hl_lines=["4-8", "12-15"]}
filtered_events = Snowflake(
    name="filtered_events",
    query_template="""
        SELECT * FROM events
        WHERE event_date >= %(start_date)s
          AND event_date <= %(end_date)s
          AND region = %(region)s
          AND score > %(min_score)s
    """,
    plugin_config=config,
    inputs={
        "start_date": str,
        "end_date": str,
        "region": str,
        "min_score": float,
    },
    output_dataframe_type=DataFrame,
)
```
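
Each `%(name)s` placeholder needs a matching key in `inputs`, and a typo on either side would otherwise surface only when the query runs. A hypothetical pre-flight check like the one below (not part of the plugin) compares the two sets up front; treating unused inputs as errors is a deliberately strict choice.

```python
import re
from typing import Dict

_PLACEHOLDER_RE = re.compile(r"%\((\w+)\)s")


def check_placeholders(query_template: str, inputs: Dict[str, type]) -> None:
    """Fail fast if placeholders and declared inputs disagree (hypothetical helper)."""
    used = set(_PLACEHOLDER_RE.findall(query_template))
    declared = set(inputs)
    missing, unused = used - declared, declared - used
    if missing or unused:
        raise ValueError(
            f"placeholders without inputs: {sorted(missing)}; inputs never used: {sorted(unused)}"
        )


check_placeholders(
    "SELECT * FROM events WHERE event_date >= %(start_date)s AND region = %(region)s",
    {"start_date": str, "region": str},
)
```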

> [!NOTE]
> The query template is normalized before execution: newlines and tabs are replaced with spaces, and consecutive whitespace is collapsed. You can format your queries across multiple lines for readability without affecting execution.
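
A minimal sketch of that normalization, assuming it behaves like a single whitespace-collapsing regular expression (whether the plugin also trims leading and trailing spaces is an assumption here). One consequence worth knowing: if runs of spaces inside quoted string literals are collapsed the same way, values where whitespace matters are better passed as inputs than written into the template. `normalize_query` is a hypothetical name.

```python
import re


def normalize_query(query: str) -> str:
    """Fold newlines, tabs and runs of whitespace into single spaces (hypothetical helper)."""
    # \s covers newlines and tabs, so one substitution handles every case.
    return re.sub(r"\s+", " ", query).strip()


print(normalize_query("\n    SELECT *\n\tFROM events\n"))  # SELECT * FROM events
```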

## Retrieving query results

If your query produces output, set `output_dataframe_type` to capture the results. Passing `DataFrame` from `flyte.io` returns a wrapper type that represents the tabular results; to materialize it into a concrete DataFrame, call `open()` with the target type and then `all()`.

```python {hl_lines=13}
from flyte.io import DataFrame

top_customers = Snowflake(
    name="top_customers",
    query_template="""
        SELECT customer_id, SUM(amount) AS total_spend
        FROM orders
        GROUP BY customer_id
        ORDER BY total_spend DESC
        LIMIT 100
    """,
    plugin_config=config,
    output_dataframe_type=DataFrame,
)
```

At present, `pandas.DataFrame` is the only supported target type. Calling the task returns the `DataFrame` wrapper, which you then materialize:

```python {hl_lines=6}
import pandas as pd

@env.task
async def analyze_top_customers() -> dict:
    df = await top_customers()
    pandas_df = await df.open(pd.DataFrame).all()
    total_spend = pandas_df["total_spend"].sum()
    return {"total_spend": float(total_spend)}
```

If you specify `pandas.DataFrame` as the `output_dataframe_type`, you do not need to call `open()` and `all()` to materialize the results.

```python {hl_lines=[1, 13, "18-19"]}
import pandas as pd

top_customers = Snowflake(
    name="top_customers",
    query_template="""
        SELECT customer_id, SUM(amount) AS total_spend
        FROM orders
        GROUP BY customer_id
        ORDER BY total_spend DESC
        LIMIT 100
    """,
    plugin_config=config,
    output_dataframe_type=pd.DataFrame,
)

@env.task
async def analyze_top_customers() -> dict:
    df = await top_customers()
    total_spend = df["total_spend"].sum()
    return {"total_spend": float(total_spend)}
```

> [!NOTE]
> Downstream tasks that read the DataFrame results also need Snowflake credentials. Inject them as secrets exposed through the `SNOWFLAKE_PRIVATE_KEY` and `SNOWFLAKE_PRIVATE_KEY_PASSPHRASE` environment variables. See [Secrets](https://www.union.ai/docs/v2/union/user-guide/task-configuration/secrets/page.md) for how to create them.
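
A downstream task can check for these variables up front, so a missing secret produces a clear error rather than a less obvious failure later. The helper below is a hypothetical standard-library sketch; the variable names match the note, and the passphrase is treated as optional, as it is for unencrypted keys.

```python
import os
from typing import Mapping, Optional, Tuple


def snowflake_key_from_env(
    environ: Optional[Mapping[str, str]] = None,
) -> Tuple[str, Optional[str]]:
    """Return (private_key, passphrase) from the environment, failing fast if the key is absent."""
    env = os.environ if environ is None else environ
    key = env.get("SNOWFLAKE_PRIVATE_KEY")
    if not key:
        raise RuntimeError(
            "SNOWFLAKE_PRIVATE_KEY is not set; add the Snowflake secrets to this task's environment"
        )
    # The passphrase is optional: unencrypted keys have none.
    return key, env.get("SNOWFLAKE_PRIVATE_KEY_PASSPHRASE") or None
```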

If you don't need query results (for example, `DDL` statements or `INSERT` queries), omit `output_dataframe_type`.

## End-to-end example

Here's a complete workflow that uses the Snowflake connector as part of a data pipeline. The workflow creates a staging table, inserts records, queries aggregated results and processes them in a downstream task.

```python
import flyte
from flyte.io import DataFrame
from flyteplugins.connectors.snowflake import Snowflake, SnowflakeConfig

config = SnowflakeConfig(
    account="myorg-myaccount",
    user="flyte_user",
    database="ANALYTICS",
    schema="PUBLIC",
    warehouse="COMPUTE_WH",
    connection_kwargs={
        "role": "ETL_ROLE",
    },
)

# Step 1: Create the staging table if it doesn't exist
create_staging = Snowflake(
    name="create_staging",
    query_template="""
        CREATE TABLE IF NOT EXISTS staging.daily_events (
            event_id STRING,
            event_date DATE,
            user_id STRING,
            event_type STRING,
            payload VARIANT
        )
    """,
    plugin_config=config,
    snowflake_private_key="snowflake",
    snowflake_private_key_passphrase="snowflake_passphrase",
)

# Step 2: Insert records into the staging table with a single batched query
insert_events = Snowflake(
    name="insert_event",
    query_template="""
        INSERT INTO staging.daily_events (event_id, event_date, user_id, event_type)
        VALUES (%(event_id)s, %(event_date)s, %(user_id)s, %(event_type)s)
    """,
    plugin_config=config,
    inputs={
        "event_id": list[str],
        "event_date": list[str],
        "user_id": list[str],
        "event_type": list[str],
    },
    snowflake_private_key="snowflake",
    snowflake_private_key_passphrase="snowflake_passphrase",
    batch=True,
)

# Step 3: Query aggregated results for a given date
daily_summary = Snowflake(
    name="daily_summary",
    query_template="""
        SELECT
            event_type,
            COUNT(*) AS event_count,
            COUNT(DISTINCT user_id) AS unique_users
        FROM staging.daily_events
        WHERE event_date = %(report_date)s
        GROUP BY event_type
        ORDER BY event_count DESC
    """,
    plugin_config=config,
    inputs={"report_date": str},
    output_dataframe_type=DataFrame,
    snowflake_private_key="snowflake",
    snowflake_private_key_passphrase="snowflake_passphrase",
)

# Create a single environment for all Snowflake tasks
snowflake_env = flyte.TaskEnvironment.from_task(
    "snowflake_env", create_staging, insert_events, daily_summary
)

# Main pipeline environment that depends on the Snowflake task environment
env = flyte.TaskEnvironment(
    name="analytics_env",
    resources=flyte.Resources(memory="512Mi"),
    image=flyte.Image.from_debian_base(name="analytics").with_pip_packages(
        "flyteplugins-snowflake", pre=True
    ),
    secrets=[
        flyte.Secret(key="snowflake", as_env_var="SNOWFLAKE_PRIVATE_KEY"),
        flyte.Secret(
            key="snowflake_passphrase", as_env_var="SNOWFLAKE_PRIVATE_KEY_PASSPHRASE"
        ),
    ],
    depends_on=[snowflake_env],
)

# Step 4: Process the results in Python
@env.task
async def generate_report(summary: DataFrame) -> dict:
    import pandas as pd

    df = await summary.open(pd.DataFrame).all()
    total_events = df["event_count"].sum()
    top_event = df.iloc[0]["event_type"]
    return {
        "total_events": int(total_events),
        "top_event_type": top_event,
        "event_types_count": len(df),
    }

# Compose the pipeline
@env.task
async def run_daily_pipeline(
    event_ids: list[str],
    event_dates: list[str],
    user_ids: list[str],
    event_types: list[str],
) -> dict:
    await create_staging()
    await insert_events(
        event_id=event_ids,
        event_date=event_dates,
        user_id=user_ids,
        event_type=event_types,
    )
    summary = await daily_summary(report_date=event_dates[0])
    return await generate_report(summary=summary)

if __name__ == "__main__":
    flyte.init_from_config()

    # Run locally
    run = flyte.with_runcontext(mode="local").run(
        run_daily_pipeline,
        event_ids=["event-1", "event-2"],
        event_dates=["2023-01-01", "2023-01-02"],
        user_ids=["user-1", "user-2"],
        event_types=["click", "view"],
    )

    # Or run remotely
    run = flyte.with_runcontext(mode="remote").run(
        run_daily_pipeline,
        event_ids=["event-1", "event-2"],
        event_dates=["2023-01-01", "2023-01-02"],
        user_ids=["user-1", "user-2"],
        event_types=["click", "view"],
    )

    print(run.url)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/connectors/snowflake/example.py*

=== PAGE: https://www.union.ai/docs/v2/union/integrations/spark ===

# Spark

The Spark plugin lets you run [Apache Spark](https://spark.apache.org/) jobs natively on Kubernetes. Flyte manages the full cluster lifecycle: provisioning a transient Spark cluster for each task execution, running the job, and tearing the cluster down on completion.

Under the hood, the plugin uses the [Spark on Kubernetes Operator](https://github.com/GoogleCloudPlatform/spark-on-k8s-operator) to create and manage Spark applications. No external Spark service or long-running cluster is required.

## When to use this plugin

- Large-scale data processing and ETL pipelines
- Jobs that benefit from Spark's distributed execution engine (Spark SQL, PySpark, Spark MLlib)
- Workloads that need Hadoop-compatible storage access (S3, GCS, HDFS)

## Installation

```bash
pip install flyteplugins-spark
```

## Configuration

Create a `Spark` configuration and pass it as `plugin_config` to a `TaskEnvironment`:

```python
import flyte
from flyteplugins.spark import Spark

spark_config = Spark(
    spark_conf={
        "spark.driver.memory": "3000M",
        "spark.executor.memory": "1000M",
        "spark.executor.cores": "1",
        "spark.executor.instances": "2",
        "spark.driver.cores": "1",
    },
)

spark_env = flyte.TaskEnvironment(
    name="spark_env",
    plugin_config=spark_config,
    image=image,  # must include Spark and flyteplugins-spark; see the example below
)
```

### `Spark` parameters

| Parameter | Type | Description |
|-----------|------|-------------|
| `spark_conf` | `Dict[str, str]` | Spark configuration key-value pairs (e.g., executor memory, cores, instances) |
| `hadoop_conf` | `Dict[str, str]` | Hadoop configuration key-value pairs (e.g., S3/GCS access settings) |
| `executor_path` | `str` | Path to the Python binary for PySpark executors |
| `applications_path` | `str` | Path to the main Spark application file |
| `driver_pod` | `PodTemplate` | Pod template for the Spark driver pod |
| `executor_pod` | `PodTemplate` | Pod template for the Spark executor pods |

### Accessing the Spark session

Inside a Spark task, the `SparkSession` is available through the task context:

```python
from flyte._context import internal_ctx

@spark_env.task
async def my_spark_task() -> int:
    ctx = internal_ctx()
    spark = ctx.data.task_context.data["spark_session"]
    # Use spark as a normal SparkSession
    df = spark.read.parquet("s3://my-bucket/data.parquet")
    return df.count()
```

### Overriding configuration at runtime

You can override Spark configuration for individual task calls using `.override()`:

```python
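from copy import deepcopy

# Copy the base config so the environment's default stays unchanged
updated_config = deepcopy(spark_config)
updated_config.spark_conf["spark.executor.instances"] = "4"

# Call from inside another async task (as `spark_overrider` does in the example below)
result = await my_spark_task.override(plugin_config=updated_config)()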
```
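
The `deepcopy` matters because `spark_conf` is a mutable dict nested inside the config object. The sketch below uses a hypothetical stand-in dataclass (`SparkConfigStandIn`, not the plugin's `Spark` class) to show that a shallow copy shares that dict with the original, so an override would silently change the environment's default, while `deepcopy` gives the override its own dict.

```python
import copy
from dataclasses import dataclass, field


# Hypothetical stand-in for the plugin's `Spark` config. Only one property
# matters here: `spark_conf` is a mutable dict nested inside the object.
@dataclass
class SparkConfigStandIn:
    spark_conf: dict = field(default_factory=dict)


base = SparkConfigStandIn(spark_conf={"spark.executor.instances": "2"})

# A shallow copy shares the nested dict, so this edit also changes `base`.
shallow = copy.copy(base)
shallow.spark_conf["spark.executor.instances"] = "4"
print(base.spark_conf["spark.executor.instances"])  # prints 4

# Reset, then use deepcopy: the nested dict is copied, so `base` is untouched.
base.spark_conf["spark.executor.instances"] = "2"
deep = copy.deepcopy(base)
deep.spark_conf["spark.executor.instances"] = "4"
print(base.spark_conf["spark.executor.instances"])  # prints 2
```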

## Example

```python
# /// script
# requires-python = "==3.13"
# dependencies = [
#    "flyte>=2.0.0b52",
#    "flyteplugins-spark"
# ]
# main = "hello_spark_nested"
# params = "3"
# ///

import random
from copy import deepcopy
from operator import add

from flyteplugins.spark.task import Spark

import flyte.remote
from flyte._context import internal_ctx

image = (
    flyte.Image.from_base("apache/spark-py:v3.4.0")
    .clone(name="spark", python_version=(3, 10), registry="ghcr.io/flyteorg")
    .with_pip_packages("flyteplugins-spark", pre=True)
)

task_env = flyte.TaskEnvironment(
    name="get_pi", resources=flyte.Resources(cpu=(1, 2), memory=("400Mi", "1000Mi")), image=image
)

spark_conf = Spark(
    spark_conf={
        "spark.driver.memory": "3000M",
        "spark.executor.memory": "1000M",
        "spark.executor.cores": "1",
        "spark.executor.instances": "2",
        "spark.driver.cores": "1",
        "spark.kubernetes.file.upload.path": "/opt/spark/work-dir",
        "spark.jars": "https://storage.googleapis.com/hadoop-lib/gcs/gcs-connector-hadoop3-latest.jar,https://repo1.maven.org/maven2/org/apache/hadoop/hadoop-aws/3.2.2/hadoop-aws-3.2.2.jar,https://repo1.maven.org/maven2/com/amazonaws/aws-java-sdk-bundle/1.12.262/aws-java-sdk-bundle-1.12.262.jar",
    },
)

spark_env = flyte.TaskEnvironment(
    name="spark_env",
    resources=flyte.Resources(cpu=(1, 2), memory=("3000Mi", "5000Mi")),
    plugin_config=spark_conf,
    image=image,
    depends_on=[task_env],
)

def f(_):
    x = random.random() * 2 - 1
    y = random.random() * 2 - 1
    return 1 if x**2 + y**2 <= 1 else 0

@task_env.task
async def get_pi(count: int, partitions: int) -> float:
    return 4.0 * count / partitions

@spark_env.task
async def hello_spark_nested(partitions: int = 3) -> float:
    n = 1 * partitions
    ctx = internal_ctx()
    spark = ctx.data.task_context.data["spark_session"]
    count = spark.sparkContext.parallelize(range(1, n + 1), partitions).map(f).reduce(add)

    return await get_pi(count, partitions)

@task_env.task
async def spark_overrider(executor_instances: int = 3, partitions: int = 4) -> float:
    updated_spark_conf = deepcopy(spark_conf)
    updated_spark_conf.spark_conf["spark.executor.instances"] = str(executor_instances)
    return await hello_spark_nested.override(plugin_config=updated_spark_conf)(partitions=partitions)

if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.run(hello_spark_nested)
    print(r.name)
    print(r.url)
    r.wait()
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/spark/spark_example.py*
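
The Spark job is a Monte Carlo estimate of π: each sample is a random point in the square [-1, 1] × [-1, 1], `f` returns 1 when the point falls inside the unit circle, and the reduce sums the hits. Because the circle covers π/4 of the square, π ≈ 4 × hits / samples. The sketch below runs the same map/reduce with Python built-ins, with no Spark involved, to make the arithmetic explicit (`estimate_pi` is a hypothetical name). In the example, the sample count `n` equals `partitions` (`n = 1 * partitions`), which is why `get_pi` can divide by `partitions`.

```python
import random
from functools import reduce
from operator import add


def f(_):
    # Same logic as `f` in the Spark example: sample a point in the square
    # and return 1 if it lands inside the unit circle.
    x = random.random() * 2 - 1
    y = random.random() * 2 - 1
    return 1 if x**2 + y**2 <= 1 else 0


def estimate_pi(samples: int) -> float:
    # Local equivalent of `parallelize(...).map(f).reduce(add)`:
    # map every sample to 0 or 1, then sum the hits.
    hits = reduce(add, map(f, range(samples)))
    # The circle covers pi/4 of the square, so pi is about 4 * hits / samples.
    return 4.0 * hits / samples


random.seed(0)
print(estimate_pi(100_000))  # close to 3.14; the exact value depends on the seed
```

More samples give a tighter estimate; in Spark, the `partitions` argument only controls how those samples are split across executors.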

## API reference

See the [Spark API reference](https://www.union.ai/docs/v2/union/api-reference/integrations/spark/_index) for full details.

=== PAGE: https://www.union.ai/docs/v2/union/integrations/wandb ===

# Weights & Biases

[Weights & Biases](https://wandb.ai) (W&B) is a platform for tracking machine learning experiments, visualizing metrics and optimizing hyperparameters. This plugin integrates W&B with Flyte, enabling you to:

- Automatically initialize W&B runs in your tasks without boilerplate
- Link directly from the Flyte UI to your W&B runs and sweeps
- Share W&B runs across parent and child tasks
- Track distributed training jobs across multiple GPUs and nodes
- Run hyperparameter sweeps with parallel agents

## Installation

```bash
pip install flyteplugins-wandb
```

You also need a W&B API key. Store it as a Flyte secret so your tasks can authenticate with W&B.

## Quick start

Here's a minimal example that logs metrics to W&B from a Flyte task:

```python
import flyte

from flyteplugins.wandb import get_wandb_run, wandb_config, wandb_init

env = flyte.TaskEnvironment(
    name="wandb-example",
    image=flyte.Image.from_debian_base(name="wandb-example").with_pip_packages(
        "flyteplugins-wandb"
    ),
    secrets=[flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY")],
)

@wandb_init
@env.task
async def train_model() -> str:
    wandb_run = get_wandb_run()

    # Your training code here
    for epoch in range(10):
        loss = 1.0 / (epoch + 1)
        wandb_run.log({"epoch": epoch, "loss": loss})

    return "Training complete"

if __name__ == "__main__":
    flyte.init_from_config()

    r = flyte.with_runcontext(
        custom_context=wandb_config(
            project="my-project",
            entity="my-team",
        ),
    ).run(train_model)

    print(f"run url: {r.url}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/wandb/quick_start.py*

This example demonstrates the core pattern:

1. **Define a task environment** with the plugin installed and your W&B API key as a secret
2. **Decorate your task** with `@wandb_init` (must be the outermost decorator, above `@env.task`)
3. **Access the run** with `get_wandb_run()` to log metrics
4. **Provide configuration** via `wandb_config()` when running the task

The plugin handles calling `wandb.init()` and `wandb.finish()` for you, and automatically adds a link to the W&B run in the Flyte UI.

![UI](https://raw.githubusercontent.com/unionai/unionai-docs-static/main/images/integrations/wandb/ui.png)

## What's next

This integration guide is split into focused sections, depending on how you want to use Weights & Biases with Flyte:

- **Weights & Biases > Experiments**: Create and manage W&B runs from Flyte tasks.
- **Weights & Biases > Distributed training**: Track experiments across multi-GPU and multi-node training jobs.
- **Weights & Biases > Sweeps**: Run hyperparameter searches and manage sweep execution from Flyte tasks.
- **Weights & Biases > Downloading logs**: Download logs and execution metadata from Weights & Biases.
- **Weights & Biases > Constraints and best practices**: Learn about limitations, edge cases and recommended patterns.
- **Weights & Biases > Manual integration**: Use Weights & Biases directly in Flyte tasks without decorators or helpers.

> **📝 Note**
>
> Additional examples, developed while testing edge cases of the plugin, are available in the [flyte-sdk repository](https://github.com/flyteorg/flyte-sdk/tree/main/plugins/wandb/examples).

=== PAGE: https://www.union.ai/docs/v2/union/integrations/wandb/experiments ===

# Experiments

The `@wandb_init` decorator automatically initializes a W&B run when your task executes and finishes it when the task completes. This section covers the different ways to use it.

## Basic usage

Apply `@wandb_init` as the outermost decorator on your task:

```python {hl_lines=1}
@wandb_init
@env.task
async def my_task() -> str:
    run = get_wandb_run()
    run.log({"metric": 42})
    return "done"
```

The decorator:

- Calls `wandb.init()` before your task code runs
- Calls `wandb.finish()` after your task completes (or fails)
- Adds a link to the W&B run in the Flyte UI

You can also use it on synchronous tasks:

```python {hl_lines=[1, 3]}
@wandb_init
@env.task
def my_sync_task() -> str:
    run = get_wandb_run()
    run.log({"metric": 42})
    return "done"
```
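
Conceptually, the decorator wraps your function so that setup runs before the body and teardown runs even if the body raises. The sketch below is not the plugin's code: `fake_init` and `fake_finish` are hypothetical stand-ins that only record calls instead of contacting W&B, and `lifecycle` is a made-up name. It shows the `try`/`finally` pattern and why a wrapper that supports both kinds of task needs separate sync and async paths.

```python
import asyncio
import functools
import inspect

events: list = []


# Hypothetical stand-ins for wandb.init() / wandb.finish(); they only record calls.
def fake_init() -> None:
    events.append("init")


def fake_finish() -> None:
    events.append("finish")


def lifecycle(fn):
    """Sketch of an init/finish wrapper that supports sync and async functions."""
    if inspect.iscoroutinefunction(fn):
        @functools.wraps(fn)
        async def async_wrapper(*args, **kwargs):
            fake_init()
            try:
                return await fn(*args, **kwargs)
            finally:
                fake_finish()  # runs on success and on failure
        return async_wrapper

    @functools.wraps(fn)
    def sync_wrapper(*args, **kwargs):
        fake_init()
        try:
            return fn(*args, **kwargs)
        finally:
            fake_finish()  # runs on success and on failure
    return sync_wrapper


@lifecycle
async def async_task() -> str:
    events.append("async body")
    return "done"


@lifecycle
def failing_task() -> None:
    events.append("sync body")
    raise RuntimeError("boom")


print(asyncio.run(async_task()))  # prints done
try:
    failing_task()
except RuntimeError:
    pass
print(events)  # ['init', 'async body', 'finish', 'init', 'sync body', 'finish']
```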

## Accessing the run object

Use `get_wandb_run()` to access the current W&B run object:

```python {hl_lines=6}
from flyteplugins.wandb import get_wandb_run

@wandb_init
@env.task
async def train() -> str:
    run = get_wandb_run()

    # Log metrics
    run.log({"loss": 0.5, "accuracy": 0.9})

    # Access run properties
    print(f"Run ID: {run.id}")
    print(f"Run URL: {run.url}")
    print(f"Project: {run.project}")

    # Log configuration
    run.config.update({"learning_rate": 0.001, "batch_size": 32})

    return run.id
```

## Parent-child task relationships

When a parent task calls child tasks, the plugin can share the same W&B run across all of them. This is useful for tracking an entire workflow in a single run.

```python {hl_lines=[1, 8, 15]}
@wandb_init
@env.task
async def child_task(x: int) -> int:
    run = get_wandb_run()
    run.log({"child_metric": x * 2})
    return x * 2

@wandb_init
@env.task
async def parent_task() -> int:
    run = get_wandb_run()
    run.log({"parent_metric": 100})

    # Child task shares the parent's run by default
    result = await child_task(5)

    return result
```

By default (`run_mode="auto"`), child tasks reuse their parent's W&B run. All metrics logged by the parent and children appear in the same run in the W&B UI.

## Run modes

The `run_mode` parameter controls how tasks create or reuse W&B runs. There are three modes:

| Mode             | Behavior                                                                   |
| ---------------- | -------------------------------------------------------------------------- |
| `auto` (default) | Create a new run if no parent run exists, otherwise reuse the parent's run |
| `new`            | Always create a new run, even if a parent run exists                       |
| `shared`         | Always reuse the parent's run (fails if no parent run exists)              |
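
The decision in this table can be written as a small function. The sketch below is a hypothetical helper (`resolve_run` is not part of `flyteplugins.wandb`) that models only the choice between reusing a parent run and creating a new one; run IDs are plain strings for illustration.

```python
from typing import Optional


def resolve_run(run_mode: str, parent_run_id: Optional[str], new_run_id: str) -> str:
    """Return the run ID a task would use under the documented run modes.

    Hypothetical helper: it mirrors the table above, not the plugin's code.
    """
    if run_mode == "auto":
        # Reuse the parent's run if there is one, otherwise start fresh.
        return parent_run_id if parent_run_id is not None else new_run_id
    if run_mode == "new":
        # Always start a fresh run, ignoring any parent.
        return new_run_id
    if run_mode == "shared":
        # Reuse is mandatory: a top-level task has no parent run to join.
        if parent_run_id is None:
            raise RuntimeError("run_mode='shared' requires a parent run")
        return parent_run_id
    raise ValueError(f"unknown run_mode: {run_mode!r}")


print(resolve_run("auto", None, "run-b"))     # prints run-b
print(resolve_run("auto", "run-a", "run-b"))  # prints run-a
print(resolve_run("new", "run-a", "run-b"))   # prints run-b
```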

### Using `run_mode="new"` for independent runs

```python {hl_lines=1}
@wandb_init(run_mode="new")
@env.task
async def independent_child(x: int) -> int:
    run = get_wandb_run()
    # This task gets its own separate run
    run.log({"independent_metric": x})
    return x

@wandb_init
@env.task
async def parent_task() -> str:
    run = get_wandb_run()
    parent_run_id = run.id

    # This child creates its own run
    await independent_child(5)

    # Parent's run is unchanged
    assert run.id == parent_run_id
    return parent_run_id
```

### Using `run_mode="shared"` for explicit sharing

```python {hl_lines=1}
@wandb_init(run_mode="shared")
@env.task
async def must_share_run(x: int) -> int:
    # This task requires a parent run to exist
    # It will fail if called as a top-level task
    run = get_wandb_run()
    run.log({"shared_metric": x})
    return x
```

## Configuration with `wandb_config`

Use `wandb_config()` to configure W&B runs. You can set it at the workflow level or override it for specific tasks, allowing you to provide configuration values at runtime.

### Workflow-level configuration

```python {hl_lines=["5-9"]}
if __name__ == "__main__":
    flyte.init_from_config()

    flyte.with_runcontext(
        custom_context=wandb_config(
            project="my-project",
            entity="my-team",
            tags=["experiment-1", "production"],
            config={"model": "resnet50", "dataset": "imagenet"},
        ),
    ).run(train_task)
```

### Overriding configuration for child tasks

Use `wandb_config()` as a context manager to override settings for specific child task calls:

```python {hl_lines=[8, 12]}
@wandb_init
@env.task
async def parent_task() -> str:
    run = get_wandb_run()
    run.log({"parent_metric": 100})

    # Override tags and config for this child call
    with wandb_config(tags=["special-run"], config={"learning_rate": 0.01}):
        await child_task(10)

    # Override run_mode for this child call
    with wandb_config(run_mode="new"):
        await child_task(20)  # Gets its own run

    return "done"
```
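
One way to picture how a per-call override layers on top of the workflow-level configuration is a stack of settings that is restored when the `with` block exits. The sketch below models that with `contextvars`; `config_override` and `child_call` are hypothetical names, and this is not how the plugin is implemented. It also assumes an override replaces a key outright; how the plugin actually combines fields such as `tags` may differ.

```python
import contextvars
from contextlib import contextmanager

# Hypothetical model of layered configuration; not the plugin's implementation.
_current_config = contextvars.ContextVar("wandb_config_sketch", default={})


@contextmanager
def config_override(**overrides):
    # Merge overrides on top of whatever is active, for the duration of the block.
    merged = {**_current_config.get(), **overrides}
    token = _current_config.set(merged)
    try:
        yield merged
    finally:
        # Restore the previous layer, so later calls see the original settings.
        _current_config.reset(token)


def child_call() -> dict:
    # Stands in for a child task reading the configuration it was launched with.
    return dict(_current_config.get())


with config_override(project="my-project", tags=["experiment-1"]):  # workflow level
    with config_override(tags=["special-run"]):  # per-call override
        print(child_call())  # {'project': 'my-project', 'tags': ['special-run']}
    print(child_call())      # {'project': 'my-project', 'tags': ['experiment-1']}
print(child_call())          # {}
```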

## Using traces with W&B runs

Flyte traces can access the parent task's W&B run without needing the `@wandb_init` decorator. This is useful for helper functions that should log to the same run:

```python {hl_lines=[1, 3]}
@flyte.trace
async def log_validation_metrics(accuracy: float, f1: float):
    run = get_wandb_run()
    if run:
        run.log({"val_accuracy": accuracy, "val_f1": f1})

@wandb_init
@env.task
async def train_and_validate() -> str:
    run = get_wandb_run()

    # Training loop
    for epoch in range(10):
        run.log({"train_loss": 1.0 / (epoch + 1)})

    # Trace logs to the same run
    await log_validation_metrics(accuracy=0.95, f1=0.92)

    return "done"
```

> **📝 Note**
>
> Do not apply `@wandb_init` to traces. Traces automatically access the parent task's run via `get_wandb_run()`.

=== PAGE: https://www.union.ai/docs/v2/union/integrations/wandb/distributed_training ===

# Distributed training

When running distributed training jobs, multiple processes run simultaneously across GPUs. The `@wandb_init` decorator automatically detects distributed training environments and coordinates W&B logging across processes.

The plugin:

- Auto-detects distributed context from environment variables (set by launchers like `torchrun`)
- Controls which processes initialize W&B runs based on the `run_mode` and `rank_scope` parameters
- Generates unique run IDs that distinguish between workers and ranks
- Adds links to W&B runs in the Flyte UI

## Quick start

Here's a minimal single-node example that logs metrics from a distributed training task. By default (`run_mode="auto"`, `rank_scope="global"`), only rank 0 logs to W&B:

```python
import flyte
import torch
import torch.distributed
from flyteplugins.pytorch.task import Elastic
from flyteplugins.wandb import get_wandb_run, wandb_config, wandb_init

image = flyte.Image.from_debian_base(name="torch-wandb").with_pip_packages(
    "flyteplugins-wandb", "flyteplugins-pytorch"
)

env = flyte.TaskEnvironment(
    name="distributed_env",
    image=image,
    resources=flyte.Resources(gpu="A100:2"),
    plugin_config=Elastic(nproc_per_node=2, nnodes=1),
    secrets=flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY"),
)

@wandb_init
@env.task
def train() -> float:
    torch.distributed.init_process_group("nccl")

    # Only rank 0 gets a W&B run object; others get None
    run = get_wandb_run()

    # Simulate training
    for step in range(100):
        loss = 1.0 / (step + 1)

        # Safe to call on all ranks - only rank 0 actually logs
        if run:
            run.log({"loss": loss, "step": step})

    torch.distributed.destroy_process_group()
    return loss

if __name__ == "__main__":
    flyte.init_from_config()
    flyte.with_runcontext(
        custom_context=wandb_config(project="my-project", entity="my-team")
    ).run(train)
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/wandb/distributed_training_quick_start.py*

A few things to note:

1. Use the `Elastic` plugin to configure distributed training (number of processes, nodes)
2. Apply `@wandb_init` as the outermost decorator
3. Check that `run` is not `None` before logging: in `auto` mode, only the primary rank has a run object

> **📝 Note**
>
> The `if run:` check is always safe regardless of run mode. In `shared` and `new` modes all ranks get a run object, but the check doesn't hurt and keeps your code portable across modes.

![Single-node auto](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/integrations/wandb/single_node_auto_flyte.png)

## Run modes in distributed training

The `run_mode` parameter controls how W&B runs are created across distributed processes. The behavior differs between single-node (one machine, multiple GPUs) and multi-node (multiple machines) setups.

### Single-node behavior

| Mode             | Which ranks log       | Result                                 |
| ---------------- | --------------------- | -------------------------------------- |
| `auto` (default) | Only rank 0           | 1 W&B run                              |
| `shared`         | All ranks to same run | 1 W&B run with metrics labeled by rank |
| `new`            | Each rank separately  | N W&B runs (grouped in UI)             |

### Multi-node behavior

For multi-node training, the `rank_scope` parameter controls the granularity of W&B runs:

- **`global`** (default): Treat all workers as one unit
- **`worker`**: Treat each worker/node independently

The combination of `run_mode` and `rank_scope` determines logging behavior:

| `run_mode` | `rank_scope` | Who initializes W&B           | W&B Runs | Grouping |
| ---------- | ------------ | ----------------------------- | -------- | -------- |
| `auto`     | `global`     | Global rank 0 only            | 1        | -        |
| `auto`     | `worker`     | Local rank 0 per worker       | N        | -        |
| `shared`   | `global`     | All ranks (shared globally)   | 1        | -        |
| `shared`   | `worker`     | All ranks (shared per worker) | N        | -        |
| `new`      | `global`     | All ranks                     | N × M    | 1 group  |
| `new`      | `worker`     | All ranks                     | N × M    | N groups |

Where `N` = number of workers/nodes, `M` = processes per worker.
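
For a concrete count, the sketch below enumerates the processes of an N-worker, M-process job and records which ones initialize W&B and which run each one logs to. It is a hypothetical model of the table (`wandb_runs` is not a plugin function), it assumes every node runs the same number of processes with global ranks assigned node by node, and it does not model W&B grouping.

```python
def wandb_runs(run_mode: str, rank_scope: str, n_workers: int, procs_per_worker: int) -> dict:
    """Map each initializing process (worker, local_rank) to the run it logs to.

    Hypothetical model of the table above. Run keys are tuples, so
    len(set(result.values())) is the number of distinct W&B runs.
    """
    if rank_scope not in ("global", "worker"):
        raise ValueError(f"unknown rank_scope: {rank_scope!r}")
    if run_mode not in ("auto", "shared", "new"):
        raise ValueError(f"unknown run_mode: {run_mode!r}")

    runs = {}
    for worker in range(n_workers):
        for local_rank in range(procs_per_worker):
            # Assumes homogeneous nodes with global ranks assigned node by node.
            global_rank = worker * procs_per_worker + local_rank
            # Run shared by every process in the same scope.
            scope_key = ("job",) if rank_scope == "global" else ("worker", worker)
            if run_mode == "auto":
                # Only the first process of each scope initializes W&B.
                first = global_rank == 0 if rank_scope == "global" else local_rank == 0
                if first:
                    runs[(worker, local_rank)] = scope_key
            elif run_mode == "shared":
                # Every process initializes W&B and joins its scope's run.
                runs[(worker, local_rank)] = scope_key
            else:  # "new": every process gets its own run
                runs[(worker, local_rank)] = ("rank", global_rank)
    return runs


for mode in ("auto", "shared", "new"):
    for scope in ("global", "worker"):
        runs = wandb_runs(mode, scope, n_workers=2, procs_per_worker=4)
        print(f"{mode}/{scope}: {len(runs)} initializing, {len(set(runs.values()))} run(s)")
```

With two nodes of four processes each, this reproduces the table: `auto` gives 1 or 2 runs, `shared` gives 1 or 2 runs with all 8 processes logging, and `new` gives 8 runs.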

### Choosing run mode and rank scope

- **`auto`** (recommended): Use when you want clean dashboards with minimal runs. Gradient synchronization keeps model weights identical across ranks, so aggregate metrics such as loss and accuracy logged from one rank are usually representative of the whole job.
- **`shared`**: Use when you need to compare metrics across ranks in a single view. Each rank's metrics are labeled with an `x_label` identifier. Useful for debugging load imbalance or per-GPU throughput.
- **`new`**: Use when you need completely separate runs per GPU, for example to track GPU-specific metrics or compare training dynamics across devices.

For multi-node training:

- Use **`rank_scope="global"`** (default) for most cases. A single consolidated run across all nodes is usually sufficient, since gradient synchronization keeps the model consistent across nodes.
- Use **`rank_scope="worker"`** for debugging and per-node analysis. This is useful when you need to inspect data distribution across nodes, compare predictions from different workers, or track metrics on individual batches outside the main node.

## Single-node multi-GPU

For single-node distributed training, configure the `Elastic` plugin with `nnodes=1` and set `nproc_per_node` to your GPU count.

### Basic example with `auto` mode

```python {hl_lines=["6-7", 13, 17, 29]}
import os

import torch
import torch.distributed
import flyte
from flyteplugins.pytorch.task import Elastic
from flyteplugins.wandb import wandb_init, get_wandb_run

env = flyte.TaskEnvironment(
    name="single_node_env",
    image=image,
    resources=flyte.Resources(gpu="A100:4"),
    plugin_config=Elastic(nproc_per_node=4, nnodes=1),
    secrets=flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY"),
)

@wandb_init # run_mode="auto" (default)
@env.task
def train_single_node() -> float:
    torch.distributed.init_process_group("nccl")
    rank = torch.distributed.get_rank()
    local_rank = int(os.environ.get("LOCAL_RANK", 0))

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)

    run = get_wandb_run()

    # Training loop - only rank 0 logs
    for epoch in range(10):
        loss = train_epoch(model, dataloader, device)

        if run:
            run.log({"epoch": epoch, "loss": loss})

    torch.distributed.destroy_process_group()
    return loss
```

### Using `shared` mode for per-rank metrics

When you need to see metrics from all GPUs in a single run, use `run_mode="shared"`:

```python {hl_lines=[3, 13, 19]}
import os

@wandb_init(run_mode="shared")
@env.task
def train_with_per_gpu_metrics() -> float:
    torch.distributed.init_process_group("nccl")
    rank = torch.distributed.get_rank()
    local_rank = int(os.environ.get("LOCAL_RANK", 0))

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)

    # In shared mode, all ranks get a run object
    run = get_wandb_run()

    for step in range(1000):
        loss, throughput = train_step(model, batch, device)

        # Each rank logs with automatic x_label identification
        if run:
            run.log({
                "loss": loss,
                "throughput_samples_per_sec": throughput,
                "gpu_memory_used": torch.cuda.memory_allocated(device),
            })

    torch.distributed.destroy_process_group()
    return loss
```

![Single-node shared](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/integrations/wandb/single_node_shared_flyte.png)

In the W&B UI, metrics from each rank appear with distinct labels, allowing you to compare GPU utilization and throughput across devices.

![Single-node shared W&B UI](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/integrations/wandb/single_node_shared_wandb.png)

### Using `new` mode for per-rank runs

When you need completely separate W&B runs for each GPU, use `run_mode="new"`. Each rank gets its own run, and runs are grouped together in the W&B UI:

```python {hl_lines=[1, "11-12"]}
@wandb_init(run_mode="new")  # Each rank gets its own run
@env.task
def train_per_rank() -> float:
    torch.distributed.init_process_group("nccl")
    rank = torch.distributed.get_rank()
    # ...

    # Each rank has its own W&B run
    run = get_wandb_run()

    # Run IDs: {base}-rank-{rank}
    # All runs are grouped under {base} in W&B UI
    run.log({"train/loss": loss.item(), "rank": rank})
    # ...
```

With `run_mode="new"`:

- Each rank creates its own W&B run
- Run IDs follow the pattern `{run_name}-{action_name}-rank-{rank}`
- All runs are grouped together in the W&B UI for comparison

## Multi-node training with `Elastic`

For multi-node distributed training, set `nnodes` to your node count. The `rank_scope` parameter controls whether you get a single W&B run across all nodes (`global`) or one run per node (`worker`).

### Global scope (default): Single run across all nodes

With `run_mode="auto"` and `rank_scope="global"` (both defaults), only global rank 0 initializes W&B, resulting in a single run for the entire distributed job:

```python {hl_lines=["11-12", "27-30", "34", "58-59", "94-97"]}
import os

import torch
import torch.distributed
import torch.nn as nn
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler

import flyte
from flyteplugins.pytorch.task import Elastic
from flyteplugins.wandb import wandb_init, wandb_config, get_wandb_run

image = flyte.Image.from_debian_base(name="torch-wandb").with_pip_packages(
    "flyteplugins-wandb", "flyteplugins-pytorch", pre=True
)

multi_node_env = flyte.TaskEnvironment(
    name="multi_node_env",
    image=image,
    resources=flyte.Resources(
        cpu=(1, 2),
        memory=("1Gi", "10Gi"),
        gpu="A100:4",
        shm="auto",
    ),
    plugin_config=Elastic(
        nproc_per_node=4,  # GPUs per node
        nnodes=2,          # Number of nodes
    ),
    secrets=flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY"),
)

@wandb_init  # rank_scope="global" by default → 1 run total
@multi_node_env.task
def train_multi_node(epochs: int, batch_size: int) -> float:
    torch.distributed.init_process_group("nccl")

    rank = torch.distributed.get_rank()
    world_size = torch.distributed.get_world_size()
    local_rank = int(os.environ.get("LOCAL_RANK", 0))

    device = torch.device(f"cuda:{local_rank}")
    torch.cuda.set_device(device)

    # Model with DDP
    model = MyModel().to(device)
    model = DDP(model, device_ids=[local_rank])

    # Distributed data loading
    dataset = MyDataset()
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
    dataloader = DataLoader(dataset, batch_size=batch_size, sampler=sampler)

    optimizer = optim.AdamW(model.parameters(), lr=1e-3)
    criterion = nn.CrossEntropyLoss()

    # Only global rank 0 gets a W&B run
    run = get_wandb_run()

    for epoch in range(epochs):
        sampler.set_epoch(epoch)
        model.train()

        for batch_idx, (data, target) in enumerate(dataloader):
            data, target = data.to(device), target.to(device)

            optimizer.zero_grad()
            output = model(data)
            loss = criterion(output, target)
            loss.backward()
            optimizer.step()

            if run and batch_idx % 100 == 0:
                run.log({
                    "train/loss": loss.item(),
                    "train/epoch": epoch,
                    "train/batch": batch_idx,
                })

        if run:
            run.log({"train/epoch_complete": epoch})

    # Barrier ensures all ranks finish before cleanup
    torch.distributed.barrier()
    torch.distributed.destroy_process_group()

    return loss.item()

if __name__ == "__main__":
    flyte.init_from_config()
    flyte.with_runcontext(
        custom_context=wandb_config(
            project="multi-node-training",
            tags=["distributed", "multi-node"],
        )
    ).run(train_multi_node, epochs=10, batch_size=32)
```

With this configuration:

- Two nodes run the task, each with 4 GPUs (8 total processes)
- Only global rank 0 creates a W&B run
- Run ID follows the pattern `{run_name}-{action_name}`
- The Flyte UI shows a single link to the W&B run

### Worker scope: One run per node

Use `rank_scope="worker"` when you want each node to have its own W&B run for per-node analysis:

```python {hl_lines=[1, 8]}
@wandb_init(rank_scope="worker")  # 1 run per worker/node
@multi_node_env.task
def train_per_worker(epochs: int, batch_size: int) -> float:
    torch.distributed.init_process_group("nccl")
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    # ...

    # Local rank 0 of each worker gets a W&B run
    run = get_wandb_run()

    if run:
        # Each worker logs to its own run
        run.log({"train/loss": loss.item()})
    # ...
```

With `run_mode="auto"`, `rank_scope="worker"`:

- Each node's local rank 0 creates a W&B run
- Run IDs follow the pattern `{run_name}-{action_name}-worker-{worker_index}`
- The Flyte UI shows links to each worker's W&B run

![Multi-node](https://raw.githubusercontent.com/unionai/unionai-docs-static/refs/heads/main/images/integrations/wandb/multi_node.png)

### Shared mode: All ranks log to the same run

Use `run_mode="shared"` when you need metrics from all ranks in a single view. Each rank's metrics are labeled with an `x_label` identifier.

#### Shared + global scope (1 run total)

```python {hl_lines=[1, 7]}
@wandb_init(run_mode="shared")  # All ranks log to 1 shared run
@multi_node_env.task
def train_shared_global() -> float:
    torch.distributed.init_process_group("nccl")
    # ...

    # All ranks get a run object, all log to the same run
    run = get_wandb_run()

    # Each rank logs with automatic x_label identification
    run.log({"train/loss": loss.item(), "rank": rank})
    # ...
```

#### Shared + worker scope (N runs, 1 per node)

```python {hl_lines=[1, 7, 10]}
@wandb_init(run_mode="shared", rank_scope="worker")  # 1 shared run per worker
@multi_node_env.task
def train_shared_worker() -> float:
    torch.distributed.init_process_group("nccl")
    # ...

    # All ranks get a run object, grouped by worker
    run = get_wandb_run()

    # Ranks on the same worker share a run
    run.log({"train/loss": loss.item(), "local_rank": local_rank})
    # ...
```

### New mode: Separate run per rank

Use `run_mode="new"` when you need completely separate runs per GPU. Runs are grouped in the W&B UI for easy comparison.

#### New + global scope (N×M runs, 1 group)

```python {hl_lines=[1, 7, 10]}
@wandb_init(run_mode="new")  # Each rank gets its own run, all in 1 group
@multi_node_env.task
def train_new_global() -> float:
    torch.distributed.init_process_group("nccl")
    # ...

    # Each rank has its own run
    run = get_wandb_run()

    # Run IDs: {base}-rank-{global_rank}
    run.log({"train/loss": loss.item()})
    # ...
```

#### New + worker scope (N×M runs, N groups)

```python {hl_lines=[1, 7, 10]}
@wandb_init(run_mode="new", rank_scope="worker")  # Each rank gets own run, grouped per worker
@multi_node_env.task
def train_new_worker() -> float:
    torch.distributed.init_process_group("nccl")
    # ...

    # Each rank has its own run, grouped by worker
    run = get_wandb_run()

    # Run IDs: {base}-worker-{idx}-rank-{local_rank}
    run.log({"train/loss": loss.item()})
    # ...
```

## How it works

The plugin automatically detects distributed training by checking environment variables set by distributed launchers like `torchrun`:

| Environment variable | Description                                              |
| -------------------- | -------------------------------------------------------- |
| `RANK`               | Global rank across all processes                         |
| `WORLD_SIZE`         | Total number of processes                                |
| `LOCAL_RANK`         | Rank within the current node                             |
| `LOCAL_WORLD_SIZE`   | Number of processes on the current node                  |
| `GROUP_RANK`         | Node/worker index (0 for first node, 1 for second, etc.) |

When these variables are present, the plugin:

1. **Determines which ranks should initialize W&B** based on `run_mode` and `rank_scope`
2. **Generates unique run IDs** that include worker and rank information
3. **Creates UI links** to the W&B runs: a single link with `rank_scope="global"`, or one link per worker with `rank_scope="worker"`

The plugin automatically adapts to your training setup, eliminating the need for manual distributed configuration.
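
The detection step can be sketched in plain Python. The example below is an illustration, not the plugin's implementation: `DistInfo` and `read_dist_info()` are hypothetical names, and falling back to single-node values when `LOCAL_RANK`, `LOCAL_WORLD_SIZE`, or `GROUP_RANK` is missing is an assumption.

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass(frozen=True)
class DistInfo:
    """Hypothetical container for launcher-provided rank information."""

    rank: int
    world_size: int
    local_rank: int
    local_world_size: int
    worker_index: int

    @property
    def is_multi_node(self) -> bool:
        # More processes in total than on this node means more than one node
        return self.world_size > self.local_world_size


def read_dist_info(env: Mapping[str, str] = os.environ) -> Optional[DistInfo]:
    """Return rank information, or None when not started by a distributed launcher."""
    if "RANK" not in env or "WORLD_SIZE" not in env:
        return None  # Not a distributed run
    rank = int(env["RANK"])
    world_size = int(env["WORLD_SIZE"])
    return DistInfo(
        rank=rank,
        world_size=world_size,
        # Assumption: fall back to single-node values if a variable is missing
        local_rank=int(env.get("LOCAL_RANK", rank)),
        local_world_size=int(env.get("LOCAL_WORLD_SIZE", world_size)),
        worker_index=int(env.get("GROUP_RANK", 0)),
    )
```

For example, the second process on the second node of a 2-node job with 4 GPUs per node sees `RANK=5`, `LOCAL_RANK=1`, and `GROUP_RANK=1`, so `read_dist_info()` reports worker 1, local rank 1, and `is_multi_node=True`.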

### Run ID patterns

| Scenario                        | Run ID pattern                          | Group                 |
| ------------------------------- | --------------------------------------- | --------------------- |
| Single-node auto/shared         | `{base}`                                | -                     |
| Single-node new                 | `{base}-rank-{rank}`                    | `{base}`              |
| Multi-node auto/shared (global) | `{base}`                                | -                     |
| Multi-node auto/shared (worker) | `{base}-worker-{idx}`                   | -                     |
| Multi-node new (global)         | `{base}-rank-{global_rank}`             | `{base}`              |
| Multi-node new (worker)         | `{base}-worker-{idx}-rank-{local_rank}` | `{base}-worker-{idx}` |

Here, `{base}` is `{run_name}-{action_name}`.
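
To make the table concrete, the hypothetical helper below reproduces the run ID and group for each row. It only mirrors the documented patterns and is not part of `flyteplugins-wandb`; a `None` group stands for the `-` entries.

```python
from typing import Optional, Tuple


def expected_run_id(
    run_name: str,
    action_name: str,
    run_mode: str,
    rank_scope: str = "global",
    multi_node: bool = False,
    rank: int = 0,
    local_rank: int = 0,
    worker_index: int = 0,
) -> Tuple[str, Optional[str]]:
    """Return (run_id, group) following the run ID pattern table (illustration only)."""
    base = f"{run_name}-{action_name}"
    if run_mode in ("auto", "shared"):
        if multi_node and rank_scope == "worker":
            return f"{base}-worker-{worker_index}", None
        return base, None
    if run_mode == "new":
        if multi_node and rank_scope == "worker":
            group = f"{base}-worker-{worker_index}"
            return f"{group}-rank-{local_rank}", group
        # Single-node, or multi-node with global scope: IDs use the global rank
        return f"{base}-rank-{rank}", base
    raise ValueError(f"unknown run_mode: {run_mode!r}")


print(expected_run_id("r1", "a0", "new", rank_scope="worker", multi_node=True, local_rank=2, worker_index=1))
# ('r1-a0-worker-1-rank-2', 'r1-a0-worker-1')
```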

=== PAGE: https://www.union.ai/docs/v2/union/integrations/wandb/sweeps ===

# Sweeps

W&B sweeps automate hyperparameter optimization by running multiple trials with different parameter combinations. The `@wandb_sweep` decorator creates a sweep and makes it easy to run trials in parallel using Flyte's distributed execution.

## Creating a sweep

Use `@wandb_sweep` to create a W&B sweep when the task executes:

```python
import flyte
import wandb
from flyteplugins.wandb import (
    get_wandb_sweep_id,
    wandb_config,
    wandb_init,
    wandb_sweep,
    wandb_sweep_config,
)

env = flyte.TaskEnvironment(
    name="wandb-example",
    image=flyte.Image.from_debian_base(name="wandb-example").with_pip_packages(
        "flyteplugins-wandb"
    ),
    secrets=[flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY")],
)

@wandb_init
def objective():
    """Objective function that W&B calls for each trial."""
    wandb_run = wandb.run
    config = wandb_run.config

    # Simulate training with hyperparameters from the sweep
    for epoch in range(config.epochs):
        loss = 1.0 / (config.learning_rate * config.batch_size) + epoch * 0.1
        wandb_run.log({"epoch": epoch, "loss": loss})

@wandb_sweep
@env.task
async def run_sweep() -> str:
    sweep_id = get_wandb_sweep_id()

    # Run 10 trials
    wandb.agent(sweep_id, function=objective, count=10)

    return sweep_id

if __name__ == "__main__":
    flyte.init_from_config()

    r = flyte.with_runcontext(
        custom_context={
            **wandb_config(project="my-project", entity="my-team"),
            **wandb_sweep_config(
                method="random",
                metric={"name": "loss", "goal": "minimize"},
                parameters={
                    "learning_rate": {"min": 0.0001, "max": 0.1},
                    "batch_size": {"values": [16, 32, 64, 128]},
                    "epochs": {"values": [5, 10, 20]},
                },
            ),
        },
    ).run(run_sweep)

    print(f"run url: {r.url}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/wandb/sweep.py*

The `@wandb_sweep` decorator:

- Creates a W&B sweep when the task starts
- Makes the sweep ID available via `get_wandb_sweep_id()`
- Adds a link to the main sweeps page in the Flyte UI

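Use `wandb_sweep_config()` to define the sweep: the search `method`, the `metric` to optimize, and the `parameters` to search over. This configuration is passed to W&B's sweep API when the sweep is created.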
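
To illustrate what `method="random"` does with the `parameters` above, here is a stand-alone sketch that draws trial configurations: a `{"min": ..., "max": ...}` entry is sampled from a range, a `{"values": [...]}` entry is picked from the list, and a `{"value": ...}` entry is fixed. This is a simplified assumption about random search that ignores W&B's other distribution options; it is not W&B's sampler, and `sample_trial()` is a hypothetical name.

```python
import random
from typing import Any, Dict


def sample_trial(parameters: Dict[str, Dict[str, Any]], rng: random.Random) -> Dict[str, Any]:
    """Draw one hyperparameter combination from a sweep-style parameter spec."""
    trial: Dict[str, Any] = {}
    for name, spec in parameters.items():
        if "value" in spec:
            trial[name] = spec["value"]  # Fixed value
        elif "values" in spec:
            trial[name] = rng.choice(spec["values"])  # Categorical choice
        elif "min" in spec and "max" in spec:
            lo, hi = spec["min"], spec["max"]
            if isinstance(lo, int) and isinstance(hi, int):
                trial[name] = rng.randint(lo, hi)  # Integer range, inclusive
            else:
                trial[name] = rng.uniform(lo, hi)  # Continuous range
        else:
            raise ValueError(f"unsupported spec for {name!r}: {spec}")
    return trial


parameters = {
    "learning_rate": {"min": 0.0001, "max": 0.1},
    "batch_size": {"values": [16, 32, 64, 128]},
    "epochs": {"values": [5, 10, 20]},
}
rng = random.Random(0)  # Seeded so the draws are reproducible
for _ in range(3):
    print(sample_trial(parameters, rng))
```

Each trial's objective function then reads these values from `wandb.run.config`.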

> **📝 Note**
>
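> Random and Bayesian searches run indefinitely: the `count` argument only limits how many trials a given agent runs, so the sweep remains in the `Running` state until you stop it.
> You can stop a running sweep from the Weights & Biases UI or from the command line with `wandb sweep --stop <entity>/<project>/<sweep_id>`.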

## Running parallel agents

Flyte's distributed execution makes it easy to run multiple sweep agents in parallel, each on its own compute resources:

```python
import asyncio
from datetime import timedelta

import flyte
import wandb
from flyteplugins.wandb import (
    get_wandb_sweep_id,
    wandb_config,
    wandb_init,
    wandb_sweep,
    wandb_sweep_config,
    get_wandb_context,
)

env = flyte.TaskEnvironment(
    name="wandb-parallel-sweep-example",
    image=flyte.Image.from_debian_base(
        name="wandb-parallel-sweep-example"
    ).with_pip_packages("flyteplugins-wandb"),
    secrets=[flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY")],
)

@wandb_init
def objective():
    wandb_run = wandb.run
    config = wandb_run.config

    for epoch in range(config.epochs):
        loss = 1.0 / (config.learning_rate * config.batch_size) + epoch * 0.1
        wandb_run.log({"epoch": epoch, "loss": loss})

@wandb_sweep
@env.task
async def sweep_agent(agent_id: int, sweep_id: str, count: int = 5) -> int:
    """Single agent that runs a subset of trials."""
    wandb.agent(
        sweep_id, function=objective, count=count, project=get_wandb_context().project
    )
    return agent_id

@wandb_sweep
@env.task
async def run_parallel_sweep(total_trials: int = 20, trials_per_agent: int = 5) -> str:
    """Orchestrate multiple agents running in parallel."""
    sweep_id = get_wandb_sweep_id()

    num_agents = (total_trials + trials_per_agent - 1) // trials_per_agent

    # Launch agents in parallel, each with its own resources
    agent_tasks = [
        sweep_agent.override(
            resources=flyte.Resources(cpu="2", memory="4Gi"),
            retries=3,
            timeout=timedelta(minutes=30),
        )(agent_id=i, sweep_id=sweep_id, count=trials_per_agent)
        for i in range(num_agents)
    ]

    await asyncio.gather(*agent_tasks)
    return sweep_id

if __name__ == "__main__":
    flyte.init_from_config()

    r = flyte.with_runcontext(
        custom_context={
            **wandb_config(project="my-project", entity="my-team"),
            **wandb_sweep_config(
                method="random",
                metric={"name": "loss", "goal": "minimize"},
                parameters={
                    "learning_rate": {"min": 0.0001, "max": 0.1},
                    "batch_size": {"values": [16, 32, 64]},
                    "epochs": {"values": [5, 10, 20]},
                },
            ),
        },
    ).run(
        run_parallel_sweep,
        total_trials=20,
        trials_per_agent=5,
    )

    print(f"run url: {r.url}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/wandb/parallel_sweep.py*

This pattern provides:

- **Distributed execution**: Each agent runs as a separate task on its own compute resources
- **Resource allocation**: Specify CPU, memory, and GPU per agent
- **Fault tolerance**: Failed agents can retry without affecting others
- **Timeout protection**: A per-agent timeout stops runaway trials
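
`run_parallel_sweep` starts `(total_trials + trials_per_agent - 1) // trials_per_agent` agents and gives each one `count=trials_per_agent`, so a total that is not a multiple of `trials_per_agent` is rounded up (22 trials at 5 per agent starts 5 agents and up to 25 trials). If you want the per-agent counts to add up exactly and stay within W&B's limit of 20 concurrent agents, a hypothetical helper like the one below could compute them:

```python
from typing import List

MAX_CONCURRENT_AGENTS = 20  # W&B's documented limit on concurrent sweep agents


def plan_agent_counts(total_trials: int, trials_per_agent: int) -> List[int]:
    """Split total_trials into per-agent counts that add up exactly."""
    if total_trials < 1 or trials_per_agent < 1:
        raise ValueError("total_trials and trials_per_agent must be positive")
    # Same ceiling division as run_parallel_sweep
    num_agents = (total_trials + trials_per_agent - 1) // trials_per_agent
    if num_agents > MAX_CONCURRENT_AGENTS:
        raise ValueError(
            f"{num_agents} agents exceeds the limit of {MAX_CONCURRENT_AGENTS}; "
            "increase trials_per_agent"
        )
    counts = [trials_per_agent] * num_agents
    # The last agent picks up whatever remains
    counts[-1] = total_trials - trials_per_agent * (num_agents - 1)
    return counts


print(plan_agent_counts(20, 5))  # [5, 5, 5, 5]
print(plan_agent_counts(22, 5))  # [5, 5, 5, 5, 2]
```

In `run_parallel_sweep`, you would then pass `count=counts[i]` instead of `count=trials_per_agent` when launching agent `i`.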

> **📝 Note**
>
> `run_parallel_sweep` links to the main Weights & Biases sweeps page, because the sweep ID cannot be determined when its link is rendered. `sweep_agent` links to the specific sweep URL.

![Sweep](https://raw.githubusercontent.com/unionai/unionai-docs-static/main/images/integrations/wandb/sweep.png)

## Writing objective functions

The objective function is called by `wandb.agent()` for each trial. It must be a regular Python function decorated with `@wandb_init`:

```python {hl_lines=["4-5", "8-9"]}
import wandb
from flyteplugins.wandb import wandb_init

@wandb_init
def objective():
    """Objective function for sweep trials."""
    # Access hyperparameters from wandb.run.config
    run = wandb.run
    config = run.config

    # Your training code: create_model, train_epoch, and validate
    # are placeholders for your own functions
    model = create_model(
        learning_rate=config.learning_rate,
        hidden_size=config.hidden_size,
    )

    for epoch in range(config.epochs):
        train_loss = train_epoch(model)
        val_loss = validate(model)

        # Log metrics - W&B tracks these for the sweep
        run.log({
            "epoch": epoch,
            "train_loss": train_loss,
            "val_loss": val_loss,
        })

    # With metric={"name": "val_loss", ...} in wandb_sweep_config(),
    # the sweep ranks trials by the last logged val_loss
```

Key points:

- Use `@wandb_init` on the objective function (not `@env.task`)
- Access hyperparameters via `wandb.run.config` (not `get_wandb_run()`, because the objective runs outside the Flyte task context)
- Log the metric specified in `wandb_sweep_config(metric=...)` so the sweep can optimize it
- The function is called multiple times by `wandb.agent()`, once per trial
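
Because the objective only reads `config` and calls `log()`, one way to check it before launching a sweep is to move the training loop into a function that takes the run object and call it with a stand-in. In this sketch, `FakeRun` and `train_trial()` are hypothetical helpers and the loss is the toy formula from the sweep examples; in a real sweep, the `@wandb_init`-decorated objective would call `train_trial(wandb.run)`.

```python
from types import SimpleNamespace
from typing import Any, Dict, List


class FakeRun:
    """Minimal stand-in exposing the two attributes the objective uses."""

    def __init__(self, **config: Any) -> None:
        self.config = SimpleNamespace(**config)  # Attribute access, like wandb.run.config
        self.logged: List[Dict[str, Any]] = []

    def log(self, data: Dict[str, Any]) -> None:
        self.logged.append(dict(data))


def train_trial(run: Any) -> None:
    """Training loop shared by the real objective and local checks."""
    config = run.config
    for epoch in range(config.epochs):
        loss = 1.0 / (config.learning_rate * config.batch_size) + epoch * 0.1
        run.log({"epoch": epoch, "loss": loss})


run = FakeRun(learning_rate=0.01, batch_size=32, epochs=3)
train_trial(run)
# The sweep optimizes the metric named in wandb_sweep_config(metric=...), here "loss"
assert all("loss" in row for row in run.logged)
print(run.logged[-1])
```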

=== PAGE: https://www.union.ai/docs/v2/union/integrations/wandb/downloading_logs ===

# Downloading logs

This integration enables downloading Weights & Biases run data, including metrics history, summary data, and synced files.

## Automatic download

Set `download_logs=True` to automatically download run data after your task completes:

```python {hl_lines=1}
@wandb_init(download_logs=True)
@env.task
async def train_with_download() -> str:
    run = get_wandb_run()

    for epoch in range(10):
        run.log({"loss": 1.0 / (epoch + 1)})

    return run.id
```

The downloaded data is traced by Flyte and appears as a `Dir` output in the Flyte UI. Downloaded files include:

- `summary.json`: Final summary metrics
- `metrics_history.json`: Step-by-step metrics history
- Any files synced by W&B (`requirements.txt`, `wandb_metadata.json`, etc.)
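
Once the `Dir` output is available as a local path, you can inspect it with ordinary file I/O. The sketch below assumes that `summary.json` holds a JSON object and `metrics_history.json` holds a JSON list of per-step row objects; both layouts are assumptions, so check the files your run actually produces. `load_wandb_logs()` and `metric_series()` are hypothetical helpers.

```python
import json
from pathlib import Path
from typing import Any, Dict, List, Tuple


def load_wandb_logs(log_dir: str) -> Tuple[Dict[str, Any], List[Dict[str, Any]]]:
    """Read summary and history files from a downloaded W&B log directory."""
    root = Path(log_dir)
    summary = json.loads((root / "summary.json").read_text())
    history = json.loads((root / "metrics_history.json").read_text())
    return summary, history


def metric_series(history: List[Dict[str, Any]], key: str) -> List[Any]:
    """Collect one metric across steps, skipping rows that don't log it."""
    return [row[key] for row in history if key in row]
```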

You can also set `download_logs=True` in `wandb_config()`:

```python {hl_lines=5}
flyte.with_runcontext(
    custom_context=wandb_config(
        project="my-project",
        entity="my-team",
        download_logs=True,
    ),
).run(train_task)
```

![Logs](https://raw.githubusercontent.com/unionai/unionai-docs-static/main/images/integrations/wandb/logs.png)

For sweeps, set `download_logs=True` on `@wandb_sweep` or `wandb_sweep_config()` to download all trial data:

```python {hl_lines=1}
@wandb_sweep(download_logs=True)
@env.task
async def run_sweep() -> str:
    sweep_id = get_wandb_sweep_id()
    wandb.agent(sweep_id, function=objective, count=10)
    return sweep_id
```

![Sweep Logs](https://raw.githubusercontent.com/unionai/unionai-docs-static/main/images/integrations/wandb/sweep_logs.png)

## Accessing run directories during execution

Use `get_wandb_run_dir()` to access the local W&B run directory during task execution. This is useful for writing custom files that get synced to W&B:

```python {hl_lines=[4, 10, "22-23"]}
import json

import torch
from flyteplugins.wandb import get_wandb_run, get_wandb_run_dir, wandb_init

@wandb_init
@env.task
def train_with_artifacts() -> str:
    run = get_wandb_run()
    local_dir = get_wandb_run_dir()
    model = torch.nn.Linear(10, 1)  # Placeholder model for illustration

    # Train your model
    for epoch in range(10):
        run.log({"loss": 1.0 / (epoch + 1)})

    # Save model checkpoint to the run directory
    model_path = f"{local_dir}/model_checkpoint.pt"
    torch.save(model.state_dict(), model_path)

    # Save custom metrics file
    with open(f"{local_dir}/custom_metrics.json", "w") as f:
        json.dump({"final_accuracy": 0.95}, f)

    return run.id
```

Files written to the run directory are automatically synced to W&B and can be accessed later via the W&B UI or by setting `download_logs=True`.

> **📝 Note**
>
> `get_wandb_run_dir()` accesses the local directory without making network calls. Files written here may have a brief delay before appearing in the W&B cloud.

=== PAGE: https://www.union.ai/docs/v2/union/integrations/wandb/constraints_and_best_practices ===

# Constraints and best practices

## Decorator ordering

`@wandb_init` and `@wandb_sweep` must be the **outermost decorators**: place them above `@env.task`, so that they are applied after it:

```python
# Correct
@wandb_init
@env.task
async def my_task():
    ...

# Incorrect - will not work
@env.task
@wandb_init
async def my_task():
    ...
```
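
Python applies stacked decorators from the bottom up, so the decorator listed first wraps the result of everything below it. The toy decorators in this sketch are hypothetical stand-ins (not the real `@wandb_init` or `@env.task`) that record when they are applied; listing `@wandb_init` above `@env.task` means it receives the object `@env.task` returns rather than the raw function.

```python
from typing import Callable, List

applied: List[str] = []


def task_like(fn: Callable) -> Callable:
    """Stand-in for @env.task: records when it is applied."""
    applied.append("task_like")
    return fn


def init_like(fn: Callable) -> Callable:
    """Stand-in for @wandb_init: records when it is applied."""
    applied.append("init_like")
    return fn


@init_like  # Listed first (outermost), applied last
@task_like  # Listed closest to the function, applied first
def my_task() -> None:
    pass


# Equivalent to: my_task = init_like(task_like(my_task))
print(applied)  # ['task_like', 'init_like']
```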

## Traces cannot use decorators

Do not apply `@wandb_init` to traces. Inside a trace, call `get_wandb_run()` to access the parent task's run, and check the result before logging, since it may be `None` when no run is active:

```python
# Correct
@flyte.trace
async def my_trace():
    run = get_wandb_run()
    if run:
        run.log({"metric": 42})

# Incorrect - don't decorate traces
@wandb_init
@flyte.trace
async def my_trace():
    ...
```

## Maximum sweep agents

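[W&B limits sweeps to a maximum of 20 concurrent agents](https://docs.wandb.ai/models/sweeps/existing-project#3-launch-agents), so keep the number of agents you launch in parallel (for example, `num_agents` in `run_parallel_sweep`) at or below 20.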

## Configuration priority

Configuration is merged with the following priority (highest to lowest):

1. Decorator parameters (`@wandb_init(project="...")`)
2. Context manager (`with wandb_config(...)`)
3. Workflow-level context (`flyte.with_runcontext(custom_context=wandb_config(...))`)
4. Auto-generated values (run ID from Flyte context)
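
This priority order behaves like layered dictionary updates, from lowest to highest priority. The sketch below is a hypothetical illustration of the rule, not the plugin's code; `merge_wandb_config()` is an invented name, and treating `None` as "not set" is an assumption.

```python
from typing import Any, Dict, Optional


def merge_wandb_config(
    auto: Dict[str, Any],
    workflow: Optional[Dict[str, Any]] = None,
    context: Optional[Dict[str, Any]] = None,
    decorator: Optional[Dict[str, Any]] = None,
) -> Dict[str, Any]:
    """Merge config layers; later (higher-priority) layers win, None values don't override."""
    merged: Dict[str, Any] = {}
    for layer in (auto, workflow, context, decorator):
        for key, value in (layer or {}).items():
            if value is not None:
                merged[key] = value
    return merged


config = merge_wandb_config(
    auto={"id": "run-abc-a0"},
    workflow={"project": "my-project", "entity": "my-team"},
    context={"project": "context-project"},
    decorator={"project": "decorator-project", "entity": None},
)
print(config)  # {'id': 'run-abc-a0', 'project': 'decorator-project', 'entity': 'my-team'}
```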

## Run ID generation

When no explicit `id` is provided, the plugin generates run IDs using the pattern:

```
{run_name}-{action_name}
```

This ensures unique, predictable IDs that can be matched between the `Wandb` link class and manual `wandb.init()` calls.

## Sync delay for local files

Files written to the run directory (via `get_wandb_run_dir()`) are synced to W&B asynchronously. There may be a brief delay before they appear in the W&B cloud or can be downloaded via `download_wandb_run_dir()`.

## Shared run mode requirements

When using `run_mode="shared"`, the task requires a parent task to have already created a W&B run. Calling a task with `run_mode="shared"` as a top-level task will fail.

## Objective functions for sweeps

Objective functions passed to `wandb.agent()` should:

- Be regular Python functions (not Flyte tasks)
- Be decorated with `@wandb_init`
- Access hyperparameters via `wandb.run.config` (not `get_wandb_run()`)
- Log the metric specified in `wandb_sweep_config(metric=...)` so the sweep can optimize it

## Error handling

The plugin can raise the following exceptions:

- `RuntimeError`: When `download_wandb_run_dir()` is called without a run ID and no active run exists
- `wandb.errors.AuthenticationError`: When `WANDB_API_KEY` is not set or invalid
- `wandb.errors.CommError`: When a run cannot be found in the W&B cloud

=== PAGE: https://www.union.ai/docs/v2/union/integrations/wandb/manual ===

# Manual integration

If you need more control over W&B initialization, you can use the `Wandb` and `WandbSweep` link classes directly instead of the decorators. This lets you call `wandb.init()` and `wandb.finish()` yourself while still getting automatic links in the Flyte UI.

## Using the Wandb link class

Add a `Wandb` link to your task to generate a link to the W&B run in the Flyte UI:

```python
import flyte
import wandb
from flyteplugins.wandb import Wandb

env = flyte.TaskEnvironment(
    name="wandb-manual-init-example",
    image=flyte.Image.from_debian_base(
        name="wandb-manual-init-example"
    ).with_pip_packages("flyteplugins-wandb"),
    secrets=[flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY")],
)

@env.task(
    links=(
        Wandb(
            project="my-project",
            entity="my-team",
            run_mode="new",
            # No id parameter - link will auto-generate from run_name-action_name
        ),
    )
)
async def train_model(learning_rate: float) -> str:
    ctx = flyte.ctx()

    # Generate run ID matching the link's auto-generated ID
    run_id = f"{ctx.action.run_name}-{ctx.action.name}"

    # Manually initialize W&B
    wandb_run = wandb.init(
        project="my-project",
        entity="my-team",
        id=run_id,
        config={"learning_rate": learning_rate},
    )

    # Your training code
    for epoch in range(10):
        loss = 1.0 / (learning_rate * (epoch + 1))
        wandb_run.log({"epoch": epoch, "loss": loss})

    # Manually finish the run
    wandb_run.finish()

    return wandb_run.id

if __name__ == "__main__":
    flyte.init_from_config()

    r = flyte.with_runcontext().run(
        train_model,
        learning_rate=0.01,
    )

    print(f"run url: {r.url}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/wandb/init_manual.py*

### With a custom run ID

If you want to use your own run ID, specify it in both the link and the `wandb.init()` call:

```python {hl_lines=[6, 14]}
@env.task(
    links=(
        Wandb(
            project="my-project",
            entity="my-team",
            id="my-custom-run-id",
        ),
    )
)
async def train_with_custom_id() -> str:
    run = wandb.init(
        project="my-project",
        entity="my-team",
        id="my-custom-run-id",  # Must match the link's ID
        resume="allow",
    )

    # Training code...
    run.finish()
    return run.id
```

### Adding links at runtime with override

You can also add links when calling a task using `.override()`:

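```python {hl_lines=18}
@env.task
async def train_model(learning_rate: float) -> str:
    ctx = flyte.ctx()
    # Match the ID the link auto-generates: {run_name}-{action_name}
    run = wandb.init(
        project="my-project",
        entity="my-team",
        id=f"{ctx.action.run_name}-{ctx.action.name}",
    )
    # ... training code ...
    run.finish()
    return run.id

@env.task
async def main() -> str:
    # Add the link when calling the task from a parent task
    return await train_model.override(
        links=(Wandb(project="my-project", entity="my-team", run_mode="new"),)
    )(learning_rate=0.01)
```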

## Using the `WandbSweep` link class

Use `WandbSweep` to add a link to a W&B sweep:

```python
import flyte
import wandb
from flyteplugins.wandb import WandbSweep

env = flyte.TaskEnvironment(
    name="wandb-manual-sweep-example",
    image=flyte.Image.from_debian_base(
        name="wandb-manual-sweep-example"
    ).with_pip_packages("flyteplugins-wandb"),
    secrets=[flyte.Secret(key="wandb_api_key", as_env_var="WANDB_API_KEY")],
)

def objective():
    with wandb.init(project="my-project", entity="my-team") as wandb_run:
        config = wandb_run.config

        for epoch in range(config.epochs):
            loss = 1.0 / (config.learning_rate * config.batch_size) + epoch * 0.1
            wandb_run.log({"epoch": epoch, "loss": loss})

@env.task(
    links=(
        WandbSweep(
            project="my-project",
            entity="my-team",
        ),
    )
)
async def manual_sweep() -> str:
    # Manually create the sweep
    sweep_config = {
        "method": "random",
        "metric": {"name": "loss", "goal": "minimize"},
        "parameters": {
            "learning_rate": {"min": 0.0001, "max": 0.1},
            "batch_size": {"values": [16, 32, 64]},
            "epochs": {"value": 10},
        },
    }

    sweep_id = wandb.sweep(sweep_config, project="my-project", entity="my-team")

    # Run the sweep
    wandb.agent(sweep_id, function=objective, count=10, project="my-project")

    return sweep_id

if __name__ == "__main__":
    flyte.init_from_config()
    r = flyte.with_runcontext().run(manual_sweep)

    print(f"run url: {r.url}")
```

*Source: https://github.com/unionai/unionai-examples/blob/main/v2/integrations/flyte-plugins/wandb/sweep_manual.py*

Without an `id`, the link points to the project's sweeps page. If you already know the sweep ID, pass it to the link so that it points to that specific sweep:

```python {hl_lines=6}
@env.task(
    links=(
        WandbSweep(
            project="my-project",
            entity="my-team",
            id="known-sweep-id",
        ),
    )
)
async def resume_sweep() -> str:
    # Resume an existing sweep
    wandb.agent(
        "known-sweep-id",
        function=objective,
        count=10,
        project="my-project",
        entity="my-team",
    )
    return "known-sweep-id"
```
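
To see where the two link classes lead, the sketch below builds the corresponding W&B page URLs. It assumes the hosted W&B app at `https://wandb.ai` and its usual `/{entity}/{project}/runs/...` and `/{entity}/{project}/sweeps/...` layout; the plugin's `Wandb` and `WandbSweep` classes may construct URLs differently (for example, for a self-managed server), and `wandb_run_url()` and `wandb_sweep_url()` are hypothetical helpers.

```python
from typing import Optional

WANDB_HOST = "https://wandb.ai"  # Assumption: hosted W&B, not a self-managed server


def wandb_run_url(entity: str, project: str, run_id: str) -> str:
    """Page for a single run, the target of a Wandb link."""
    return f"{WANDB_HOST}/{entity}/{project}/runs/{run_id}"


def wandb_sweep_url(entity: str, project: str, sweep_id: Optional[str] = None) -> str:
    """Page a WandbSweep link opens: a specific sweep, or the project's sweeps list."""
    url = f"{WANDB_HOST}/{entity}/{project}/sweeps"
    return f"{url}/{sweep_id}" if sweep_id else url


print(wandb_sweep_url("my-team", "my-project"))  # https://wandb.ai/my-team/my-project/sweeps
print(wandb_sweep_url("my-team", "my-project", "known-sweep-id"))  # https://wandb.ai/my-team/my-project/sweeps/known-sweep-id
```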

