diff --git a/import-automation/workflow/embedding-helper/Dockerfile b/import-automation/workflow/embedding-helper/Dockerfile new file mode 100644 index 0000000000..2f61d779ec --- /dev/null +++ b/import-automation/workflow/embedding-helper/Dockerfile @@ -0,0 +1,10 @@ +FROM python:3.10-slim + +WORKDIR /app + +COPY requirements.txt . +RUN pip install -r requirements.txt + +COPY . . + +CMD ["python", "main.py"] diff --git a/import-automation/workflow/embedding-helper/embedding_utils.py b/import-automation/workflow/embedding-helper/embedding_utils.py new file mode 100644 index 0000000000..ce7b4762c5 --- /dev/null +++ b/import-automation/workflow/embedding-helper/embedding_utils.py @@ -0,0 +1,169 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Helper utilities for embedding workflows.""" + +import itertools +import logging +import time +from datetime import datetime +from google.cloud.spanner_v1.param_types import TIMESTAMP, STRING, Array, Struct, StructField + + +_BATCH_SIZE = 500 + +def get_latest_lock_timestamp(database): + """Gets the latest AcquiredTimestamp from IngestionLock table. + + Args: + database: google.cloud.spanner.Database object. + + Returns: + The latest AcquiredTimestamp as a datetime object, or None if no entries exist. 
+    """ +    time_lock_sql = "SELECT MAX(AcquiredTimestamp) FROM IngestionLock" +    try: +        with database.snapshot() as snapshot: +            results = snapshot.execute_sql(time_lock_sql) +            for row in results: +                return row[0] +    except Exception as e: +        logging.error(f"Error fetching latest lock timestamp: {e}") +        raise +    return None + +def get_updated_nodes(database, timestamp, node_types): +    """Gets subject_ids, names, and types from Node table where update_timestamp > timestamp. +    Yields results to avoid loading all into memory. + +    Args: +        database: google.cloud.spanner.Database object. +        timestamp: datetime object to filter by. +        node_types: A list of strings representing the node types to filter by. + +    Yields: +        Dictionaries containing subject_id, name, and types. +    """ +    timestamp_condition = "update_timestamp > @timestamp" if timestamp else "TRUE" + +    updated_node_sql = f""" +        SELECT subject_id, name, types FROM Node +        WHERE name IS NOT NULL +        AND {timestamp_condition} +        AND EXISTS ( +            SELECT 1 FROM UNNEST(types) AS t WHERE t IN UNNEST(@node_types) +        ) +    """ + +    params = {"node_types": node_types} +    param_types = {"node_types": Array(STRING)} + +    if timestamp: +        logging.info(f"Filtering valid nodes updated after {timestamp}") +        params["timestamp"] = timestamp +        param_types["timestamp"] = TIMESTAMP +    else: +        logging.info("No timestamp provided, reading all valid nodes.") + +    try: +        with database.snapshot() as snapshot: +            results = snapshot.execute_sql(updated_node_sql, params=params, param_types=param_types) +            fields = None +            for row in results: +                if fields is None: +                    fields = [field.name for field in results.fields] +                yield dict(zip(fields, row)) +    except Exception as e: +        logging.error(f"Error fetching updated nodes: {e}") +        raise + + +def filter_and_convert_nodes(nodes_generator): +    """Filters out nodes without a name and converts dictionaries to tuples. +    Reads from a generator and yields results. 
+ +    Args: +        nodes_generator: A generator yielding dictionaries containing subject_id, name, and types. + +    Yields: +        Tuples (subject_id, embedding_content, types). +    """ +    for node in nodes_generator: +        if node.get("name"): +            yield (node.get("subject_id"), node.get("name"), node.get("types")) + + +def generate_embeddings_partitioned(database, nodes_generator): +    """Generates embeddings in batches using standard transactions. +    Processes nodes in chunks of 500 to avoid transaction size limits. +    Accepts a generator to avoid loading all nodes into memory. + +    Args: +        database: google.cloud.spanner.Database object. +        nodes_generator: A generator yielding tuples containing (subject_id, embedding_content, types). + +    Returns: +        The number of affected rows. +    """ +    global _BATCH_SIZE +    total_rows_affected = 0 + +    logging.info(f"Generating embeddings in batches of {_BATCH_SIZE}.") + +    embeddings_sql = """ +        INSERT OR UPDATE INTO NodeEmbeddings (subject_id, embedding_content, embeddings, types) +        SELECT subject_id, content, embeddings.values, types +        FROM ML.PREDICT( +            MODEL text_embeddings, +            (SELECT subject_id, embedding_content AS content, types, "RETRIEVAL_QUERY" AS task_type FROM UNNEST(@nodes)) +        ) +    """ + +    struct_type = Struct([ +        StructField("subject_id", STRING), +        StructField("embedding_content", STRING), +        StructField("types", Array(STRING)) +    ]) + +    def chunked(iterable, n): +        it = iter(iterable) +        while True: +            chunk = list(itertools.islice(it, n)) +            if not chunk: +                break +            yield chunk + +    for batch in chunked(nodes_generator, _BATCH_SIZE): +        params = {"nodes": batch} +        param_types = {"nodes": Array(struct_type)} + +        def _execute_dml(transaction): +            return transaction.execute_update(embeddings_sql, params=params, param_types=param_types, timeout=300) + +        try: +            row_count = database.run_in_transaction(_execute_dml) +            total_rows_affected += row_count +            logging.info(f"Processed batch of {len(batch)} nodes. 
Affected {row_count} rows.") + time.sleep(0.5) + except Exception as e: + logging.error(f"Error executing batch transaction: {e}") + raise + + logging.info(f"Completed batch processing. Total affected rows: {total_rows_affected}") + return total_rows_affected + + + + + diff --git a/import-automation/workflow/embedding-helper/main.py b/import-automation/workflow/embedding-helper/main.py new file mode 100644 index 0000000000..a0247027b8 --- /dev/null +++ b/import-automation/workflow/embedding-helper/main.py @@ -0,0 +1,55 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import logging +from google.cloud import spanner +from embedding_utils import get_latest_lock_timestamp, get_updated_nodes, filter_and_convert_nodes, generate_embeddings_partitioned + +logging.basicConfig(level=logging.INFO) + +def main(): + # Read configuration from environment variables + instance_id = os.environ.get("SPANNER_INSTANCE") + database_id = os.environ.get("SPANNER_DATABASE") + project_id = os.environ.get("SPANNER_PROJECT") + + if not instance_id or not database_id: + logging.error("SPANNER_INSTANCE or SPANNER_DATABASE environment variables not set.") + exit(1) + + logging.info(f"Connecting to Spanner instance: {instance_id}, database: {database_id}, project: {project_id}") + + spanner_client = spanner.Client(project=project_id) + instance = spanner_client.instance(instance_id) + database = instance.database(database_id) + + node_types = ["StatisticalVariable", "Topic"] + + try: + logging.info(f"Job started. Fetching all nodes for types: {node_types}") + timestamp = get_latest_lock_timestamp(database) + nodes = get_updated_nodes(database, timestamp, node_types) + + converted_nodes = filter_and_convert_nodes(nodes) + + affected_rows = generate_embeddings_partitioned(database, converted_nodes) + + logging.info(f"Job completed successfully. 
Total affected rows: {affected_rows}") + except Exception as e: + logging.error(f"Job failed with error: {e}") + exit(1) + +if __name__ == "__main__": + main() diff --git a/import-automation/workflow/embedding-helper/requirements.txt b/import-automation/workflow/embedding-helper/requirements.txt new file mode 100644 index 0000000000..58cfe74bd9 --- /dev/null +++ b/import-automation/workflow/embedding-helper/requirements.txt @@ -0,0 +1,3 @@ +functions-framework==3.* +google-cloud-spanner +google-auth diff --git a/import-automation/workflow/embedding-helper/test/embedding_utils_test.py b/import-automation/workflow/embedding-helper/test/embedding_utils_test.py new file mode 100644 index 0000000000..9e261ea627 --- /dev/null +++ b/import-automation/workflow/embedding-helper/test/embedding_utils_test.py @@ -0,0 +1,111 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import unittest +from unittest.mock import MagicMock, patch +from datetime import datetime +import sys +import os + +# Add parent directory of current file (src directory) to the path +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from embedding_utils import ( + get_latest_lock_timestamp, + get_updated_nodes, + filter_and_convert_nodes, + generate_embeddings_partitioned +) + +class TestEmbeddingUtils(unittest.TestCase): + + def test_get_latest_lock_timestamp(self): + mock_database = MagicMock() + mock_snapshot = MagicMock() + mock_database.snapshot.return_value.__enter__.return_value = mock_snapshot + expected_timestamp = datetime(2026, 4, 20, 12, 0, 0) + mock_snapshot.execute_sql.return_value = [(expected_timestamp,)] + + timestamp = get_latest_lock_timestamp(mock_database) + self.assertEqual(timestamp, expected_timestamp) + + def test_get_updated_nodes(self): + mock_database = MagicMock() + mock_snapshot = MagicMock() + mock_database.snapshot.return_value.__enter__.return_value = mock_snapshot + + class MockField: + def __init__(self, name): + self.name = name + + class MockResults: + def __init__(self, rows, field_names): + self.rows = rows + self.fields = [MockField(name) for name in field_names] + + def __iter__(self): + return iter(self.rows) + + mock_snapshot.execute_sql.return_value = MockResults( + rows=[("dc/1", "Node 1", ["Topic"])], + field_names=["subject_id", "name", "types"] + ) + + nodes = list(get_updated_nodes(mock_database, None, ["Topic"])) + self.assertEqual(len(nodes), 1) + self.assertEqual(nodes[0]["subject_id"], "dc/1") + self.assertEqual(nodes[0]["name"], "Node 1") + self.assertEqual(nodes[0]["types"], ["Topic"]) + + def test_filter_and_convert_nodes(self): + nodes = [ + {"subject_id": "dc/1", "name": "Node 1", "types": ["Topic"]}, + {"subject_id": "dc/2", "name": None, "types": ["StatisticalVariable"]}, + {"subject_id": "dc/3", "name": "Node 3", "types": ["Topic", "StatisticalVariable"]}, + 
{"subject_id": "dc/4", "name": "", "types": ["StatisticalVariable"]} + ] + + converted = list(filter_and_convert_nodes(nodes)) + self.assertEqual(len(converted), 2) + self.assertEqual(converted[0], ("dc/1", "Node 1", ["Topic"])) + self.assertEqual(converted[1], ("dc/3", "Node 3", ["Topic", "StatisticalVariable"])) + + @patch('embedding_utils._BATCH_SIZE', 2) + def test_generate_embeddings_partitioned(self): + mock_database = MagicMock() + + nodes = [ + ("dc/1", "Node 1", ["Topic"]), + ("dc/2", "Node 2", ["Topic"]), + ("dc/3", "Node 3", ["Topic"]), + ("dc/4", "Node 4", ["Topic"]), + ("dc/5", "Node 5", ["Topic"]), + ("dc/6", "Node 6", ["Topic"]), + ("dc/7", "Node 7", ["Topic"]), + ("dc/8", "Node 8", ["Topic"]) + ] + + def side_effect(func): + mock_transaction = MagicMock() + mock_transaction.execute_update.return_value = 2 + return func(mock_transaction) + + mock_database.run_in_transaction.side_effect = side_effect + + affected_rows = generate_embeddings_partitioned(mock_database, nodes) + self.assertEqual(affected_rows, 8) + self.assertEqual(mock_database.run_in_transaction.call_count, 4) + +if __name__ == '__main__': + unittest.main() diff --git a/import-automation/workflow/embedding-helper/test/main_test.py b/import-automation/workflow/embedding-helper/test/main_test.py new file mode 100644 index 0000000000..6c3f0b3b94 --- /dev/null +++ b/import-automation/workflow/embedding-helper/test/main_test.py @@ -0,0 +1,61 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest +from unittest.mock import MagicMock, patch +import sys +import os + +from datetime import datetime +sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +import main + +class TestMain(unittest.TestCase): + + @patch.dict(os.environ, { + "SPANNER_INSTANCE": "test-instance", + "SPANNER_DATABASE": "test-db", + "SPANNER_PROJECT": "test-proj" + }) + @patch('main.spanner.Client') + @patch('main.get_latest_lock_timestamp') + @patch('main.get_updated_nodes') + @patch('main.filter_and_convert_nodes') + @patch('main.generate_embeddings_partitioned') + def test_main_e2e_success(self, mock_generate, mock_filter, mock_nodes, mock_timestamp, mock_spanner): + mock_database = MagicMock() + mock_instance = MagicMock() + mock_instance.database.return_value = mock_database + mock_spanner.return_value.instance.return_value = mock_instance + + timestamp_val = datetime(2026, 4, 20, 12, 0, 0) + mock_timestamp.return_value = timestamp_val + mock_nodes.return_value = [{"subject_id": "dc/1", "name": "Node 1", "types": ["Topic"]}] + mock_filter.return_value = [("dc/1", "Node 1", ["Topic"])] + mock_generate.return_value = 1 + + try: + main.main() + except SystemExit as e: + self.assertEqual(e.code, 0) + + mock_spanner.assert_called_once_with(project="test-proj") + mock_timestamp.assert_called_once_with(mock_database) + mock_nodes.assert_called_once_with(mock_database, timestamp_val, ["StatisticalVariable", "Topic"]) + mock_filter.assert_called_once() + mock_generate.assert_called_once() + +if __name__ == '__main__': + unittest.main()