Skip to content

Commit

Permalink
feat: Add OpenLineage support for some SQL to GCS operators
Browse files Browse the repository at this point in the history
Signed-off-by: Kacper Muda <[email protected]>
  • Loading branch information
kacpermuda committed Dec 27, 2024
1 parent 60cd5ad commit 73b4ad9
Show file tree
Hide file tree
Showing 8 changed files with 216 additions and 8 deletions.
2 changes: 1 addition & 1 deletion generated/provider_dependencies.json
Original file line number Diff line number Diff line change
Expand Up @@ -631,7 +631,7 @@
"google": {
"deps": [
"PyOpenSSL>=23.0.0",
"apache-airflow-providers-common-compat>=1.3.0",
"apache-airflow-providers-common-compat>=1.4.0",
"apache-airflow-providers-common-sql>=1.20.0",
"apache-airflow>=2.9.0",
"asgiref>=3.5.2",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import logging
from typing import TYPE_CHECKING

log = logging.getLogger(__name__)

if TYPE_CHECKING:
from airflow.providers.openlineage.sqlparser import get_openlineage_facets_with_sql

else:
try:
from airflow.providers.openlineage.sqlparser import get_openlineage_facets_with_sql
except ImportError:

def get_openlineage_facets_with_sql(
hook,
sql: str | list[str],
conn_id: str,
database: str | None,
):
try:
from airflow.providers.openlineage.sqlparser import SQLParser
except ImportError:
log.debug("SQLParser could not be imported from OpenLineage provider.")
return None

try:
from airflow.providers.openlineage.utils.utils import should_use_external_connection

use_external_connection = should_use_external_connection(hook)
except ImportError:
# OpenLineage provider release < 1.8.0 - we always use connection
use_external_connection = True

connection = hook.get_connection(conn_id)
try:
database_info = hook.get_openlineage_database_info(connection)
except AttributeError:
log.debug("%s has no database info provided", hook)
database_info = None

if database_info is None:
return None

try:
sql_parser = SQLParser(
dialect=hook.get_openlineage_database_dialect(connection),
default_schema=hook.get_openlineage_default_schema(),
)
except AttributeError:
log.debug("%s failed to get database dialect", hook)
return None

operator_lineage = sql_parser.generate_openlineage_metadata_from_sql(
sql=sql,
hook=hook,
database_info=database_info,
database=database,
sqlalchemy_engine=hook.get_sqlalchemy_engine(),
use_connection=use_external_connection,
)

return operator_lineage


__all__ = ["get_openlineage_facets_with_sql"]
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import base64
from datetime import date, datetime, time, timedelta
from decimal import Decimal
from functools import cached_property
from typing import TYPE_CHECKING

try:
from MySQLdb.constants import FIELD_TYPE
Expand All @@ -37,6 +39,9 @@
from airflow.providers.google.cloud.transfers.sql_to_gcs import BaseSQLToGCSOperator
from airflow.providers.mysql.hooks.mysql import MySqlHook

if TYPE_CHECKING:
from airflow.providers.openlineage.extractors import OperatorLineage


class MySQLToGCSOperator(BaseSQLToGCSOperator):
"""
Expand Down Expand Up @@ -77,10 +82,13 @@ def __init__(self, *, mysql_conn_id="mysql_default", ensure_utc=False, **kwargs)
self.mysql_conn_id = mysql_conn_id
self.ensure_utc = ensure_utc

@cached_property
def db_hook(self) -> MySqlHook:
return MySqlHook(mysql_conn_id=self.mysql_conn_id)

def query(self):
"""Query mysql and returns a cursor to the results."""
mysql = MySqlHook(mysql_conn_id=self.mysql_conn_id)
conn = mysql.get_conn()
conn = self.db_hook.get_conn()
cursor = conn.cursor()
if self.ensure_utc:
# Ensure TIMESTAMP results are in UTC
Expand Down Expand Up @@ -140,3 +148,20 @@ def convert_type(self, value, schema_type: str, **kwargs):
else:
value = base64.standard_b64encode(value).decode("ascii")
return value

def get_openlineage_facets_on_start(self) -> OperatorLineage | None:
from airflow.providers.common.compat.openlineage.facet import SQLJobFacet
from airflow.providers.common.compat.openlineage.utils.sql import get_openlineage_facets_with_sql
from airflow.providers.openlineage.extractors import OperatorLineage

sql_parsing_result = get_openlineage_facets_with_sql(
hook=self.db_hook,
sql=self.sql,
conn_id=self.mysql_conn_id,
database=None,
)
gcs_output_datasets = self._get_openlineage_output_datasets()
if sql_parsing_result:
sql_parsing_result.outputs = gcs_output_datasets
return sql_parsing_result
return OperatorLineage(outputs=gcs_output_datasets, job_facets={"sql": SQLJobFacet(self.sql)})
Original file line number Diff line number Diff line change
Expand Up @@ -24,13 +24,18 @@
import time
import uuid
from decimal import Decimal
from functools import cached_property
from typing import TYPE_CHECKING

import pendulum
from slugify import slugify

from airflow.providers.google.cloud.transfers.sql_to_gcs import BaseSQLToGCSOperator
from airflow.providers.postgres.hooks.postgres import PostgresHook

if TYPE_CHECKING:
from airflow.providers.openlineage.extractors import OperatorLineage


class _PostgresServerSideCursorDecorator:
"""
Expand Down Expand Up @@ -132,10 +137,13 @@ def _unique_name(self):
)
return None

@cached_property
def db_hook(self) -> PostgresHook:
return PostgresHook(postgres_conn_id=self.postgres_conn_id)

def query(self):
"""Query Postgres and returns a cursor to the results."""
hook = PostgresHook(postgres_conn_id=self.postgres_conn_id)
conn = hook.get_conn()
conn = self.db_hook.get_conn()
cursor = conn.cursor(name=self._unique_name())
cursor.execute(self.sql, self.parameters)
if self.use_server_side_cursor:
Expand Down Expand Up @@ -180,3 +188,20 @@ def convert_type(self, value, schema_type, stringify_dict=True):
if isinstance(value, Decimal):
return float(value)
return value

def get_openlineage_facets_on_start(self) -> OperatorLineage | None:
from airflow.providers.common.compat.openlineage.facet import SQLJobFacet
from airflow.providers.common.compat.openlineage.utils.sql import get_openlineage_facets_with_sql
from airflow.providers.openlineage.extractors import OperatorLineage

sql_parsing_result = get_openlineage_facets_with_sql(
hook=self.db_hook,
sql=self.sql,
conn_id=self.postgres_conn_id,
database=self.db_hook.database,
)
gcs_output_datasets = self._get_openlineage_output_datasets()
if sql_parsing_result:
sql_parsing_result.outputs = gcs_output_datasets
return sql_parsing_result
return OperatorLineage(outputs=gcs_output_datasets, job_facets={"sql": SQLJobFacet(self.sql)})
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from airflow.providers.google.cloud.hooks.gcs import GCSHook

if TYPE_CHECKING:
from airflow.providers.common.compat.openlineage.facet import OutputDataset
from airflow.utils.context import Context


Expand Down Expand Up @@ -151,6 +152,7 @@ def __init__(
self.partition_columns = partition_columns
self.write_on_empty = write_on_empty
self.parquet_row_group_size = parquet_row_group_size
self._uploaded_file_names: list[str] = []

def execute(self, context: Context):
if self.partition_columns:
Expand Down Expand Up @@ -501,3 +503,13 @@ def _upload_to_gcs(self, file_to_upload):
gzip=self.gzip if is_data_file else False,
metadata=metadata,
)
self._uploaded_file_names.append(object_name)

def _get_openlineage_output_datasets(self) -> list[OutputDataset]:
"""Retrieve OpenLineage output datasets."""
from airflow.providers.common.compat.openlineage.facet import OutputDataset
from airflow.providers.google.cloud.openlineage.utils import extract_ds_name_from_gcs_path

return [
OutputDataset(namespace=f"gs://{self.bucket}", name=extract_ds_name_from_gcs_path(self.filename))
]
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

from functools import cached_property
from typing import TYPE_CHECKING, Any

from airflow.providers.google.cloud.transfers.sql_to_gcs import BaseSQLToGCSOperator
Expand All @@ -26,6 +27,8 @@
from trino.client import TrinoResult
from trino.dbapi import Cursor as TrinoCursor

from airflow.providers.openlineage.extractors import OperatorLineage


class _TrinoToGCSTrinoCursorAdapter:
"""
Expand Down Expand Up @@ -181,10 +184,13 @@ def __init__(self, *, trino_conn_id: str = "trino_default", **kwargs):
super().__init__(**kwargs)
self.trino_conn_id = trino_conn_id

@cached_property
def db_hook(self) -> TrinoHook:
return TrinoHook(trino_conn_id=self.trino_conn_id)

def query(self):
"""Query trino and returns a cursor to the results."""
trino = TrinoHook(trino_conn_id=self.trino_conn_id)
conn = trino.get_conn()
conn = self.db_hook.get_conn()
cursor = conn.cursor()
self.log.info("Executing: %s", self.sql)
cursor.execute(self.sql)
Expand All @@ -207,3 +213,20 @@ def convert_type(self, value, schema_type, **kwargs):
:param schema_type: BigQuery data type
"""
return value

def get_openlineage_facets_on_start(self) -> OperatorLineage | None:
from airflow.providers.common.compat.openlineage.facet import SQLJobFacet
from airflow.providers.common.compat.openlineage.utils.sql import get_openlineage_facets_with_sql
from airflow.providers.openlineage.extractors import OperatorLineage

sql_parsing_result = get_openlineage_facets_with_sql(
hook=self.db_hook,
sql=self.sql,
conn_id=self.trino_conn_id,
database=None,
)
gcs_output_datasets = self._get_openlineage_output_datasets()
if sql_parsing_result:
sql_parsing_result.outputs = gcs_output_datasets
return sql_parsing_result
return OperatorLineage(outputs=gcs_output_datasets, job_facets={"sql": SQLJobFacet(self.sql)})
2 changes: 1 addition & 1 deletion providers/src/airflow/providers/google/provider.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ versions:

dependencies:
- apache-airflow>=2.9.0
- apache-airflow-providers-common-compat>=1.3.0
- apache-airflow-providers-common-compat>=1.4.0
- apache-airflow-providers-common-sql>=1.20.0
- asgiref>=3.5.2
- dill>=0.2.3
Expand Down
39 changes: 39 additions & 0 deletions providers/src/airflow/providers/openlineage/sqlparser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Callable

import sqlparse
Expand All @@ -30,6 +31,7 @@
create_information_schema_query,
get_table_schemas,
)
from airflow.providers.openlineage.utils.utils import should_use_external_connection
from airflow.typing_compat import TypedDict
from airflow.utils.log.logging_mixin import LoggingMixin

Expand All @@ -38,6 +40,9 @@
from sqlalchemy.engine import Engine

from airflow.hooks.base import BaseHook
from airflow.providers.common.sql.hooks.sql import DbApiHook

log = logging.getLogger(__name__)

DEFAULT_NAMESPACE = "default"
DEFAULT_INFORMATION_SCHEMA_COLUMNS = [
Expand Down Expand Up @@ -397,3 +402,37 @@ def _get_tables_hierarchy(
tables = schemas.setdefault(normalize_name(table.schema) if table.schema else None, [])
tables.append(table.name)
return hierarchy


def get_openlineage_facets_with_sql(
hook: DbApiHook, sql: str | list[str], conn_id: str, database: str | None
) -> OperatorLineage | None:
connection = hook.get_connection(conn_id)
try:
database_info = hook.get_openlineage_database_info(connection)
except AttributeError:
database_info = None

if database_info is None:
log.debug("%s has no database info provided", hook)
return None

try:
sql_parser = SQLParser(
dialect=hook.get_openlineage_database_dialect(connection),
default_schema=hook.get_openlineage_default_schema(),
)
except AttributeError:
log.debug("%s failed to get database dialect", hook)
return None

operator_lineage = sql_parser.generate_openlineage_metadata_from_sql(
sql=sql,
hook=hook,
database_info=database_info,
database=database,
sqlalchemy_engine=hook.get_sqlalchemy_engine(),
use_connection=should_use_external_connection(hook),
)

return operator_lineage

0 comments on commit 73b4ad9

Please sign in to comment.