feat(sdk): add DataFlow/DataJob & improve lineage capabilities #13281

Draft · wants to merge 7 commits into base: master
@@ -0,0 +1,91 @@
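"""Example: dataset → DataJob → dataset lineage with the DataHub SDK.

Creates two Snowflake datasets and an Airflow DataFlow/DataJob pair, then
links them together with LineageClient.add_datajob_lineage.
"""
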
import logging

from datahub.sdk.dataflow import DataFlow
from datahub.sdk.datajob import DataJob
from datahub.sdk.dataset import Dataset
from datahub.sdk.lineage_client import LineageClient
from datahub.sdk.main_client import DataHubClient

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def create_assets(client: DataHubClient) -> tuple:
    # Create input dataset
    input_dataset = Dataset(
        platform="snowflake",
        name="input_sales_data",
        display_name="Sales Data",
        env="PROD",
        description="Source sales data",
        schema=[
            ("order_id", "string"),
            ("customer_id", "string"),
            ("product_id", "string"),
            ("amount", "decimal"),
        ],
    )
    client.entities.upsert(input_dataset)
    logger.info(f"Created input dataset: {input_dataset.display_name}")

    # Create output dataset
    output_dataset = Dataset(
        platform="snowflake",
        name="output_sales_summary",
        display_name="Sales Summary",
        env="PROD",
        description="Processed sales summary",
        schema=[
            ("customer_id", "string"),
            ("total_orders", "integer"),
            ("total_amount", "decimal"),
        ],
    )
    client.entities.upsert(output_dataset)
    logger.info(f"Created output dataset: {output_dataset.display_name}")

    # Create dataflow
    processing_flow = DataFlow(
        id="sales_data_processing",
        name="Sales Data Processing",
        description="Data flow for processing sales data",
        platform="airflow",
    )
    client.entities.upsert(processing_flow)

    # Create datajob
    processing_job = DataJob(
        id="process_sales_data",
        flow_urn=processing_flow.urn,
        name="Process Sales Data",
        description="Transform sales data into summary statistics",
        platform="airflow",
    )
    client.entities.upsert(processing_job)
    logger.info(f"Created datajob: {processing_job.name}")

    return input_dataset, output_dataset, processing_job


if __name__ == "__main__":
    client = DataHubClient.from_env()
    lineage_client = LineageClient(client=client)
    input_dataset, output_dataset, processing_job = create_assets(client)

    # Create lineage connections
    logger.info("\n=== Creating lineage connections ===")

    # Input dataset to job
    lineage_client.add_datajob_lineage(
        upstream=input_dataset.urn,
        downstream=processing_job.urn,
    )
    logger.info(f"Added lineage: {input_dataset.display_name} → {processing_job.name}")

    # Job to output dataset
    lineage_client.add_datajob_lineage(
        upstream=processing_job.urn,
        downstream=output_dataset.urn,
    )
    logger.info(f"Added lineage: {processing_job.name} → {output_dataset.display_name}")
157 changes: 157 additions & 0 deletions metadata-ingestion/examples/ai/dh_sdk_client_dataset_lineage_sample.py
@@ -0,0 +1,157 @@
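"""Example: dataset-to-dataset lineage variants with the DataHub SDK LineageClient.

Creates four Snowflake datasets, then demonstrates transform lineage (with and
without a column mapping), copy lineage (auto_strict, auto_fuzzy, and
table-level only), and lineage parsed from a SQL query.
"""
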
import logging

from datahub.metadata.urns import DatasetUrn
from datahub.sdk.dataset import Dataset
from datahub.sdk.lineage_client import LineageClient
from datahub.sdk.main_client import DataHubClient

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def create_test_datasets(client: DataHubClient) -> None:
    # Input datasets
    raw_website_logs = Dataset(
        platform="snowflake",
        name="raw_website_traffic_logs",
        display_name="Raw Website Traffic Logs",
        description="Unprocessed web traffic data for analysis",
        schema=[
            ("log_id", "string"),
            ("timestamp", "date"),
            ("user_id", "string"),
            ("page_url", "string"),
            ("session_duration", "integer"),
            ("traffic_source", "string"),
        ],
    )
    client.entities.upsert(raw_website_logs)

    # Processed datasets
    processed_user_metrics = Dataset(
        platform="snowflake",
        name="processed_user_engagement_metrics",
        display_name="Processed User Engagement Metrics",
        description="Processed and aggregated user engagement data",
        schema=[
            ("metric_id", "string"),
            ("analysis_date", "date"),
            ("user_segment", "string"),
            ("page_category", "string"),
            ("avg_session_duration", "decimal"),
            ("engagement_score", "decimal"),
        ],
    )
    client.entities.upsert(processed_user_metrics)

    # Dimension dataset
    user_segments_dimension = Dataset(
        platform="snowflake",
        name="user_segments_dimension",
        display_name="User Segments Dimension",
        description="Predefined user segmentation reference data",
        schema=[
            ("segment_id", "string"),
            ("segment_name", "string"),
            ("segment_criteria", "string"),
            ("priority_level", "decimal"),
        ],
    )
    client.entities.upsert(user_segments_dimension)

    # Final report dataset
    user_engagement_summary = Dataset(
        platform="snowflake",
        name="user_engagement_summary_report",
        display_name="User Engagement Summary Report",
        description="Final aggregated user engagement summary report",
        schema=[
            ("report_period", "date"),
            ("user_segment", "string"),
            ("total_engagement_hours", "decimal"),
            ("average_engagement_score", "decimal"),
        ],
    )
    client.entities.upsert(user_engagement_summary)


def test_lineage_connections(
    client: DataHubClient, lineage_client: LineageClient
) -> None:
    # Get datasets from the client
    raw_website_traffic_log = client.entities.get(
        DatasetUrn(name="raw_website_traffic_logs", platform="snowflake")
    )
    processed_user_engagement_metrics = client.entities.get(
        DatasetUrn(name="processed_user_engagement_metrics", platform="snowflake")
    )
    user_segments_dimension = client.entities.get(
        DatasetUrn(name="user_segments_dimension", platform="snowflake")
    )

    # 1. Basic table-level transform lineage (no column mapping)
    lineage_client.add_dataset_transform_lineage(
        upstream=raw_website_traffic_log.urn,
        downstream=processed_user_engagement_metrics.urn,
    )

    # 2. Transform lineage with column mapping
    column_mapping = {
        "metric_id": ["log_id"],
        "analysis_date": ["timestamp"],
        "avg_session_duration": ["session_duration"],
    }
    lineage_client.add_dataset_transform_lineage(
        upstream=raw_website_traffic_log.urn,
        downstream=processed_user_engagement_metrics.urn,
        column_lineage=column_mapping,
    )

    # 3. Copy lineage with auto_strict column matching
    lineage_client.add_dataset_copy_lineage(
        upstream=raw_website_traffic_log.urn,
        downstream=processed_user_engagement_metrics.urn,
        column_lineage="auto_strict",
    )

    # 4. Copy lineage with auto_fuzzy column matching
    lineage_client.add_dataset_copy_lineage(
        upstream=user_segments_dimension.urn,
        downstream=processed_user_engagement_metrics.urn,
        column_lineage="auto_fuzzy",
    )

    # 5. Copy lineage with no column lineage (table-level only)
    lineage_client.add_dataset_copy_lineage(
        upstream=raw_website_traffic_log.urn,
        downstream=processed_user_engagement_metrics.urn,
        column_lineage=None,
    )

    # 6. SQL-based lineage
    sql_query = """
    CREATE TABLE user_engagement_summary_report AS
    SELECT
        analysis_date AS report_period,
        s.segment_name AS user_segment,
        SUM(avg_session_duration) AS total_engagement_hours,
        AVG(engagement_score) AS average_engagement_score
    FROM processed_user_engagement_metrics m
    JOIN user_segments_dimension s ON m.user_segment = s.segment_id
    GROUP BY 1, 2
    """
    lineage_client.add_dataset_lineage_from_sql(
        query_text=sql_query, platform="snowflake"
    )


if __name__ == "__main__":
    client = DataHubClient.from_env()
    lineage_client = LineageClient(client=client)

    # Create test datasets
    create_test_datasets(client)

    # Test lineage connections
    test_lineage_connections(client, lineage_client)
4 changes: 4 additions & 0 deletions metadata-ingestion/src/datahub/sdk/_all_entities.py
@@ -1,6 +1,8 @@
from typing import Dict, List, Type

from datahub.sdk.container import Container
from datahub.sdk.dataflow import DataFlow
from datahub.sdk.datajob import DataJob
from datahub.sdk.dataset import Dataset
from datahub.sdk.entity import Entity
from datahub.sdk.mlmodel import MLModel
@@ -10,6 +12,8 @@
ENTITY_CLASSES_LIST: List[Type[Entity]] = [
    Container,
    Dataset,
    DataFlow,
    DataJob,
    MLModel,
    MLModelGroup,
]
2 changes: 2 additions & 0 deletions metadata-ingestion/src/datahub/sdk/_shared.py
@@ -29,6 +29,7 @@
    ContainerUrn,
    CorpGroupUrn,
    CorpUserUrn,
    DataFlowUrn,
    DataJobUrn,
    DataPlatformInstanceUrn,
    DataPlatformUrn,
@@ -51,6 +52,7 @@
UrnOrStr: TypeAlias = Union[Urn, str]
DatasetUrnOrStr: TypeAlias = Union[str, DatasetUrn]
DatajobUrnOrStr: TypeAlias = Union[str, DataJobUrn]
DataflowUrnOrStr: TypeAlias = Union[str, DataFlowUrn]

ActorUrn: TypeAlias = Union[CorpUserUrn, CorpGroupUrn]
