feat: Add OpenLineage support for CloudSQLExecuteQueryOperator #45182

Open · wants to merge 1 commit into base: main
2 changes: 1 addition & 1 deletion generated/provider_dependencies.json
@@ -631,7 +631,7 @@
  "google": {
    "deps": [
      "PyOpenSSL>=23.0.0",
-      "apache-airflow-providers-common-compat>=1.3.0",
+      "apache-airflow-providers-common-compat>=1.4.0",
      "apache-airflow-providers-common-sql>=1.20.0",
      "apache-airflow>=2.9.0",
      "asgiref>=3.5.2",
84 changes: 84 additions & 0 deletions providers/src/airflow/providers/common/compat/openlineage/utils/sql.py (new file; path inferred from the import in the operator below)
@@ -0,0 +1,84 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from __future__ import annotations

import logging
from typing import TYPE_CHECKING

log = logging.getLogger(__name__)

if TYPE_CHECKING:
    from airflow.providers.openlineage.sqlparser import get_openlineage_facets_with_sql

else:
    try:
        from airflow.providers.openlineage.sqlparser import get_openlineage_facets_with_sql
    except ImportError:

        def get_openlineage_facets_with_sql(
            hook,
            sql: str | list[str],
            conn_id: str,
            database: str | None,
        ):
            try:
                from airflow.providers.openlineage.sqlparser import SQLParser
            except ImportError:
                log.debug("SQLParser could not be imported from OpenLineage provider.")
                return None

            try:
                from airflow.providers.openlineage.utils.utils import should_use_external_connection

                use_external_connection = should_use_external_connection(hook)
            except ImportError:
                # OpenLineage provider release < 1.8.0 - we always use connection
                use_external_connection = True

            connection = hook.get_connection(conn_id)
            try:
                database_info = hook.get_openlineage_database_info(connection)
            except AttributeError:
                log.debug("%s has no database info provided", hook)
                database_info = None

            if database_info is None:
                return None

            try:
                sql_parser = SQLParser(
                    dialect=hook.get_openlineage_database_dialect(connection),
                    default_schema=hook.get_openlineage_default_schema(),
                )
            except AttributeError:
                log.debug("%s failed to get database dialect", hook)
                return None

            operator_lineage = sql_parser.generate_openlineage_metadata_from_sql(
                sql=sql,
                hook=hook,
                database_info=database_info,
                database=database,
                sqlalchemy_engine=hook.get_sqlalchemy_engine(),
                use_connection=use_external_connection,
            )

            return operator_lineage


__all__ = ["get_openlineage_facets_with_sql"]
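
This shim lets any provider import `get_openlineage_facets_with_sql` through the common.compat layer and degrade gracefully when the installed OpenLineage provider predates the helper. A minimal consumer sketch (the operator class and its `db_hook`/`sql`/`conn_id` attributes are hypothetical, for illustration only):

```python
from airflow.models.baseoperator import BaseOperator
from airflow.providers.common.compat.openlineage.utils.sql import get_openlineage_facets_with_sql


class MySqlBackedOperator(BaseOperator):  # hypothetical operator, not part of this PR
    def get_openlineage_facets_on_complete(self, _):
        # Returns None when the OpenLineage provider is absent or too old,
        # so lineage collection never fails the task itself.
        return get_openlineage_facets_with_sql(
            hook=self.db_hook,  # assumed attribute holding a DbApiHook instance
            sql=self.sql,
            conn_id=self.conn_id,
            database=self.database,
        )
```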
providers/src/airflow/providers/google/cloud/hooks/cloud_sql.py
@@ -67,6 +67,8 @@
    from google.cloud.secretmanager_v1 import AccessSecretVersionResponse
    from requests import Session

+    from airflow.providers.common.sql.hooks.sql import DbApiHook
+
UNIX_PATH_MAX = 108

# Time to sleep between active checks of the operation results
@@ -1146,7 +1148,7 @@ def get_sqlproxy_runner(self) -> CloudSqlProxyRunner:
            gcp_conn_id=self.gcp_conn_id,
        )

-    def get_database_hook(self, connection: Connection) -> BaseHook:
+    def get_database_hook(self, connection: Connection) -> DbApiHook:
        """
        Retrieve database hook.
@@ -1156,7 +1158,7 @@ def get_database_hook(self, connection: Connection) -> BaseHook:
        if self.database_type == "postgres":
            from airflow.providers.postgres.hooks.postgres import PostgresHook

-            db_hook: BaseHook = PostgresHook(connection=connection, database=self.database)
+            db_hook: DbApiHook = PostgresHook(connection=connection, database=self.database)
        else:
            from airflow.providers.mysql.hooks.mysql import MySqlHook

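
Narrowing the return type from `BaseHook` to `DbApiHook` is what lets callers use SQL-hook methods without casts. A minimal sketch of the difference as a static type checker sees it (the function below is illustrative, not from this PR):

```python
from airflow.providers.common.sql.hooks.sql import DbApiHook


def run_statements(db_hook: DbApiHook) -> None:
    # .run() and .get_sqlalchemy_engine() are declared on DbApiHook but not on
    # BaseHook, so the old BaseHook annotation forced casts or type: ignore here.
    db_hook.run("SELECT 1")
    engine = db_hook.get_sqlalchemy_engine()
```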
providers/src/airflow/providers/google/cloud/operators/cloud_sql.py
@@ -20,6 +20,7 @@
from __future__ import annotations

from collections.abc import Iterable, Mapping, Sequence
+from contextlib import contextmanager
from functools import cached_property
from typing import TYPE_CHECKING, Any

@@ -38,8 +39,7 @@

if TYPE_CHECKING:
    from airflow.models import Connection
-    from airflow.providers.mysql.hooks.mysql import MySqlHook
-    from airflow.providers.postgres.hooks.postgres import PostgresHook
+    from airflow.providers.openlineage.extractors import OperatorLineage
    from airflow.utils.context import Context


@@ -1256,7 +1256,8 @@ def __init__(
        self.ssl_client_key = ssl_client_key
        self.ssl_secret_id = ssl_secret_id

-    def _execute_query(self, hook: CloudSQLDatabaseHook, database_hook: PostgresHook | MySqlHook) -> None:
+    @contextmanager
+    def cloud_sql_proxy_context(self, hook: CloudSQLDatabaseHook):
        cloud_sql_proxy_runner = None
        try:
            if hook.use_proxy:
@@ -1266,27 +1267,27 @@ def _execute_query(self, hook: CloudSQLDatabaseHook, database_hook: PostgresHook
                # be taken over here by another bind(0).
                # It's quite unlikely to happen though!
                cloud_sql_proxy_runner.start_proxy()
-            self.log.info('Executing: "%s"', self.sql)
-            database_hook.run(self.sql, self.autocommit, parameters=self.parameters)
+            yield
        finally:
            if cloud_sql_proxy_runner:
                cloud_sql_proxy_runner.stop_proxy()

    def execute(self, context: Context):
-        self.gcp_connection = BaseHook.get_connection(self.gcp_conn_id)
-
        hook = self.hook
        hook.validate_ssl_certs()
        connection = hook.create_connection()
        hook.validate_socket_path_length()
        database_hook = hook.get_database_hook(connection=connection)
        try:
-            self._execute_query(hook, database_hook)
+            with self.cloud_sql_proxy_context(hook):
+                self.log.info('Executing: "%s"', self.sql)
+                database_hook.run(self.sql, self.autocommit, parameters=self.parameters)
        finally:
            hook.cleanup_database_hook()

    @cached_property
    def hook(self):
+        self.gcp_connection = BaseHook.get_connection(self.gcp_conn_id)
        return CloudSQLDatabaseHook(
            gcp_cloudsql_conn_id=self.gcp_cloudsql_conn_id,
            gcp_conn_id=self.gcp_conn_id,
@@ -1297,3 +1298,14 @@ def hook(self):
            ssl_key=self.ssl_client_key,
            ssl_secret_id=self.ssl_secret_id,
        )
+
+    def get_openlineage_facets_on_complete(self, _) -> OperatorLineage | None:
+        from airflow.providers.common.compat.openlineage.utils.sql import get_openlineage_facets_with_sql
+
+        with self.cloud_sql_proxy_context(self.hook):
+            return get_openlineage_facets_with_sql(
+                hook=self.hook.db_hook,
+                sql=self.sql,  # type:ignore[arg-type]  # Iterable[str] instead of list[str]
+                conn_id=self.gcp_cloudsql_conn_id,
+                database=self.hook.database,
+            )
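
With `get_openlineage_facets_on_complete` in place, lineage is collected automatically once the OpenLineage provider is configured; DAGs need no changes. A minimal usage sketch (the connection IDs are placeholders):

```python
from airflow.providers.google.cloud.operators.cloud_sql import CloudSQLExecuteQueryOperator

query_task = CloudSQLExecuteQueryOperator(
    task_id="run_query",
    gcp_cloudsql_conn_id="my_cloudsql_conn",  # placeholder connection ID
    gcp_conn_id="my_gcp_conn",  # placeholder connection ID
    sql="SELECT * FROM orders",
)
# After the task finishes, the OpenLineage listener invokes
# get_openlineage_facets_on_complete(), which re-enters cloud_sql_proxy_context()
# so the SQL parser's schema queries reach the database the same way the query did.
```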
2 changes: 1 addition & 1 deletion providers/src/airflow/providers/google/provider.yaml
@@ -101,7 +101,7 @@ versions:

dependencies:
  - apache-airflow>=2.9.0
-  - apache-airflow-providers-common-compat>=1.3.0
+  - apache-airflow-providers-common-compat>=1.4.0
  - apache-airflow-providers-common-sql>=1.20.0
  - asgiref>=3.5.2
  - dill>=0.2.3
39 changes: 39 additions & 0 deletions providers/src/airflow/providers/openlineage/sqlparser.py
@@ -16,6 +16,7 @@
# under the License.
from __future__ import annotations

+import logging
from typing import TYPE_CHECKING, Callable

import sqlparse
@@ -30,6 +31,7 @@
    create_information_schema_query,
    get_table_schemas,
)
+from airflow.providers.openlineage.utils.utils import should_use_external_connection
from airflow.typing_compat import TypedDict
from airflow.utils.log.logging_mixin import LoggingMixin

@@ -38,6 +40,9 @@
    from sqlalchemy.engine import Engine

    from airflow.hooks.base import BaseHook
+    from airflow.providers.common.sql.hooks.sql import DbApiHook
+
+log = logging.getLogger(__name__)

DEFAULT_NAMESPACE = "default"
DEFAULT_INFORMATION_SCHEMA_COLUMNS = [
@@ -397,3 +402,37 @@ def _get_tables_hierarchy(
            tables = schemas.setdefault(normalize_name(table.schema) if table.schema else None, [])
            tables.append(table.name)
        return hierarchy
+
+
+def get_openlineage_facets_with_sql(
+    hook: DbApiHook, sql: str | list[str], conn_id: str, database: str | None
+) -> OperatorLineage | None:
+    connection = hook.get_connection(conn_id)
+    try:
+        database_info = hook.get_openlineage_database_info(connection)
+    except AttributeError:
+        database_info = None
+
+    if database_info is None:
+        log.debug("%s has no database info provided", hook)
+        return None
+
+    try:
+        sql_parser = SQLParser(
+            dialect=hook.get_openlineage_database_dialect(connection),
+            default_schema=hook.get_openlineage_default_schema(),
+        )
+    except AttributeError:
+        log.debug("%s failed to get database dialect", hook)
+        return None
+
+    operator_lineage = sql_parser.generate_openlineage_metadata_from_sql(
+        sql=sql,
+        hook=hook,
+        database_info=database_info,
+        database=database,
+        sqlalchemy_engine=hook.get_sqlalchemy_engine(),
+        use_connection=should_use_external_connection(hook),
+    )

+    return operator_lineage
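
The new helper is also callable directly from any `DbApiHook`-based integration. A minimal sketch, assuming a working Postgres connection named `postgres_default`:

```python
from airflow.providers.openlineage.sqlparser import get_openlineage_facets_with_sql
from airflow.providers.postgres.hooks.postgres import PostgresHook

hook = PostgresHook(postgres_conn_id="postgres_default")
lineage = get_openlineage_facets_with_sql(
    hook=hook,
    sql="INSERT INTO dst SELECT * FROM src",
    conn_id="postgres_default",
    database=None,  # passed through to the parser; the connection's database applies
)
# lineage is an OperatorLineage with input/output datasets, or None when the
# hook exposes no OpenLineage database info.
```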
71 changes: 71 additions & 0 deletions providers/tests/google/cloud/operators/test_cloud_sql.py
@@ -19,11 +19,19 @@

import os
from unittest import mock
+from unittest.mock import MagicMock

import pytest

from airflow.exceptions import AirflowException, TaskDeferred
from airflow.models import Connection
+from airflow.providers.common.compat.openlineage.facet import (
+    Dataset,
+    SchemaDatasetFacet,
+    SchemaDatasetFacetFields,
+    SQLJobFacet,
+)
+from airflow.providers.common.sql.hooks.sql import DbApiHook
from airflow.providers.google.cloud.operators.cloud_sql import (
    CloudSQLCloneInstanceOperator,
    CloudSQLCreateInstanceDatabaseOperator,
@@ -822,3 +830,66 @@ def test_create_operator_with_too_long_unix_socket_path(self, get_connection):
            operator.execute(None)
        err = ctx.value
        assert "The UNIX socket path length cannot exceed" in str(err)

+    @pytest.mark.parametrize(
+        "connection_port, default_port, expected_port",
+        [(None, 4321, 4321), (1234, None, 1234), (1234, 4321, 1234)],
+    )
+    def test_execute_openlineage_events(self, connection_port, default_port, expected_port):
+        class DBApiHookForTests(DbApiHook):
+            conn_name_attr = "sql_default"
+            get_conn = MagicMock(name="conn")
+            get_connection = MagicMock()
+
+            def get_openlineage_database_info(self, connection):
+                from airflow.providers.openlineage.sqlparser import DatabaseInfo
+
+                return DatabaseInfo(
+                    scheme="sqlscheme",
+                    authority=DbApiHook.get_openlineage_authority_part(connection, default_port=default_port),
+                )
+
+        dbapi_hook = DBApiHookForTests()
+
+        class CloudSQLExecuteQueryOperatorForTest(CloudSQLExecuteQueryOperator):
+            @property
+            def hook(self):
+                return MagicMock(db_hook=dbapi_hook, database="")
+
+        sql = """CREATE TABLE IF NOT EXISTS popular_orders_day_of_week (
+            order_day_of_week VARCHAR(64) NOT NULL,
+            order_placed_on TIMESTAMP NOT NULL,
+            orders_placed INTEGER NOT NULL
+        );
+        FORGOT TO COMMENT"""
+        op = CloudSQLExecuteQueryOperatorForTest(task_id="task_id", sql=sql)
+        DB_SCHEMA_NAME = "PUBLIC"
+        rows = [
+            (DB_SCHEMA_NAME, "popular_orders_day_of_week", "order_day_of_week", 1, "varchar"),
+            (DB_SCHEMA_NAME, "popular_orders_day_of_week", "order_placed_on", 2, "timestamp"),
+            (DB_SCHEMA_NAME, "popular_orders_day_of_week", "orders_placed", 3, "int4"),
+        ]
+        dbapi_hook.get_connection.return_value = Connection(
+            conn_id="sql_default", conn_type="postgresql", host="host", port=connection_port
+        )
+        dbapi_hook.get_conn.return_value.cursor.return_value.fetchall.side_effect = [rows, []]
+
+        lineage = op.get_openlineage_facets_on_complete(None)
+        assert len(lineage.inputs) == 0
+        assert lineage.job_facets == {"sql": SQLJobFacet(query=sql)}
+        assert lineage.run_facets["extractionError"].failedTasks == 1
+        assert lineage.outputs == [
+            Dataset(
+                namespace=f"sqlscheme://host:{expected_port}",
+                name="PUBLIC.popular_orders_day_of_week",
+                facets={
+                    "schema": SchemaDatasetFacet(
+                        fields=[
+                            SchemaDatasetFacetFields(name="order_day_of_week", type="varchar"),
+                            SchemaDatasetFacetFields(name="order_placed_on", type="timestamp"),
+                            SchemaDatasetFacetFields(name="orders_placed", type="int4"),
+                        ]
+                    )
+                },
+            )
+        ]
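
The parametrization pins down port precedence in the dataset namespace: the port set on the connection wins, and `default_port` only applies when the connection has none. An illustrative reimplementation of that rule (the real logic lives in `DbApiHook.get_openlineage_authority_part`):

```python
def expected_authority(host: str, connection_port: int | None, default_port: int | None) -> str:
    # Connection port takes precedence; default_port is only a fallback.
    port = connection_port if connection_port is not None else default_port
    return f"{host}:{port}" if port is not None else host

assert expected_authority("host", None, 4321) == "host:4321"
assert expected_authority("host", 1234, 4321) == "host:1234"
```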