From 53b551b499539bfbb63a5ed9176ed5799b3504da Mon Sep 17 00:00:00 2001 From: shixiao-coder Date: Fri, 17 Apr 2026 12:41:17 -0400 Subject: [PATCH 1/7] V0 version of embedding ingestion core. It includes: requirements.txt to install the related packages embedding_utils.py including function to: - read the latest timestamp for the lock, if no timestamp set None - function to find Node IDs to update based on timestamp, nodetype, validity by none empty - function to convert the Node and fields to ID and embedding_content - function to generate embeddings from the ID and embedding content in batch --- .../embedding-helper/embedding_utils.py | 169 ++++++++++++++++++ .../workflow/embedding-helper/main.py | 54 ++++++ .../embedding-helper/requirements.txt | 2 + 3 files changed, 225 insertions(+) create mode 100644 import-automation/workflow/embedding-helper/embedding_utils.py create mode 100644 import-automation/workflow/embedding-helper/main.py create mode 100644 import-automation/workflow/embedding-helper/requirements.txt diff --git a/import-automation/workflow/embedding-helper/embedding_utils.py b/import-automation/workflow/embedding-helper/embedding_utils.py new file mode 100644 index 0000000000..ff506bf078 --- /dev/null +++ b/import-automation/workflow/embedding-helper/embedding_utils.py @@ -0,0 +1,169 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Helper utilities for embedding workflows.""" + +import logging +from datetime import datetime +from google.cloud.spanner_v1.param_types import TIMESTAMP, STRING, Array, Struct, Struct, StructField + + +def get_latest_lock_timestamp(database): + """Gets the latest AcquiredTimestamp from IngestionLock table. + + Args: + database: google.cloud.spanner.Database object. + + Returns: + The latest AcquiredTimestamp as a datetime object, or None if no entries exist. + """ + sql = "SELECT MAX(AcquiredTimestamp) FROM IngestionLock" + try: + with database.snapshot() as snapshot: + results = snapshot.execute_sql(sql) + for row in results: + return row[0] + except Exception as e: + logging.error(f"Error fetching latest lock timestamp: {e}") + raise + return None + +def get_updated_nodes(database, timestamp, node_types): + """Gets subject_ids and names from Node table where update_timestamp > timestamp. + + Args: + database: google.cloud.spanner.Database object. + timestamp: datetime object to filter by. + node_types: A list of strings representing the node types to filter by. + + Returns: + A list of dictionaries containing subject_id and name. 
+ """ + if not timestamp: + logging.info("No timestamp provided, reading all nodes.") + sql = """ + SELECT subject_id, name FROM Node + WHERE name IS NOT NULL + AND EXISTS ( + SELECT 1 FROM UNNEST(types) AS t WHERE t IN UNNEST(@node_types) + ) + """ + params = {"node_types": node_types} + param_types = {"node_types": Array(STRING)} + else: + logging.info(f"Filtering nodes updated after {timestamp}") + sql = """ + SELECT subject_id, name FROM Node + WHERE update_timestamp > @timestamp + AND name IS NOT NULL + AND EXISTS ( + SELECT 1 FROM UNNEST(types) AS t WHERE t IN UNNEST(@node_types) + ) + """ + params = {"timestamp": timestamp, "node_types": node_types} + param_types = {"timestamp": TIMESTAMP, "node_types": Array(STRING)} + + nodes = [] + try: + with database.snapshot() as snapshot: + results = snapshot.execute_sql(sql, params=params, param_types=param_types) + fields = None + for row in results: + if fields is None: + fields = [field.name for field in results.fields] + nodes.append(dict(zip(fields, row))) + except Exception as e: + logging.error(f"Error fetching updated nodes: {e}") + raise + return nodes + + +def filter_and_convert_nodes(nodes): + """Filters out nodes without a name and converts dictionaries to tuples. + Reads 'name' from input and maps it to output tuple. + + Args: + nodes: A list of dictionaries containing subject_id and name. + + Returns: + A list of tuples (subject_id, embedding_content). + """ + valid_tuples = [] + valid_tuples = [ + (node.get("subject_id"), node.get("name")) + for node in nodes + if node.get("name") + ] + return valid_tuples + + +def generate_embeddings_partitioned(database, nodes): + """Generates embeddings in batches using standard transactions. + Processes nodes in chunks of 500 to avoid transaction size limits. + + Args: + database: google.cloud.spanner.Database object. + nodes: A list of tuples containing (subject_id, embedding_content). + + Returns: + The number of affected rows. 
+ """ + if not nodes: + logging.info("No nodes to update.") + return 0 + + BATCH_SIZE = 100 + total_rows_affected = 0 + + logging.info(f"Generating embeddings for {len(nodes)} nodes in batches of {BATCH_SIZE}.") + + sql = """ + INSERT OR UPDATE INTO NodeEmbeddings (subject_id, embedding_content, embeddings) + SELECT subject_id, content, embeddings.values + FROM ML.PREDICT( + MODEL text_embeddings, + (SELECT subject_id, embedding_content AS content, "RETRIEVAL_QUERY" AS task_type FROM UNNEST(@nodes)) + ) + """ + + struct_type = Struct([ + StructField("subject_id", STRING), + StructField("embedding_content", STRING) + ]) + + for i in range(0, len(nodes), BATCH_SIZE): + batch = nodes[i : i + BATCH_SIZE] + + params = {"nodes": batch} + param_types = {"nodes": Array(struct_type)} + + def _execute_dml(transaction): + return transaction.execute_update(sql, params=params, param_types=param_types) + + try: + row_count = database.run_in_transaction(_execute_dml) + total_rows_affected += row_count + logging.info(f"Processed batch of {len(batch)} nodes. Affected {row_count} rows.") + time.sleep(0.5) + except Exception as e: + logging.error(f"Error executing batch transaction: {e}") + raise + + logging.info(f"Completed batch processing. Total affected rows: {total_rows_affected}") + return total_rows_affected + + + + + diff --git a/import-automation/workflow/embedding-helper/main.py b/import-automation/workflow/embedding-helper/main.py new file mode 100644 index 0000000000..a7f0ccf9ba --- /dev/null +++ b/import-automation/workflow/embedding-helper/main.py @@ -0,0 +1,54 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import functions_framework +import logging +from flask import jsonify + +logging.getLogger().setLevel(logging.INFO) + +@functions_framework.http +def embedding_helper(request): + """ + HTTP Cloud Function for handling embedding tasks. + Takes request and argument actionType. + """ + request_json = request.get_json(silent=True) + if not request_json: + return ('Request is not a valid JSON', 400) + + if 'actionType' not in request_json: + return ("'actionType' parameter is missing", 400) + + action_type = request_json['actionType'] + logging.info(f"Received request for actionType: {action_type}") + + if action_type == 'initilization': + return jsonify({ + "status": "success", + "message": "initilization action triggered" + }), 200 + elif action_type == 'incremental_update': + return jsonify({ + "status": "success", + "message": "incremental_update action triggered" + }), 200 + elif action_type == 'manual_update': + return jsonify({ + "status": "success", + "message": "manual_update action triggered" + }), 200 + else: + logging.warning(f"Unknown actionType: {action_type}") + return (f"Unknown actionType: {action_type}", 400) diff --git a/import-automation/workflow/embedding-helper/requirements.txt b/import-automation/workflow/embedding-helper/requirements.txt new file mode 100644 index 0000000000..754e2bd573 --- /dev/null +++ b/import-automation/workflow/embedding-helper/requirements.txt @@ -0,0 +1,2 @@ +functions-framework==3.* +google-cloud-spanner From 25a340674468b7011de9e16219df23c327c8dd98 Mon Sep 17 00:00:00 2001 From: shixiao-coder Date: Fri, 17 Apr 
2026 14:07:17 -0400 Subject: [PATCH 2/7] Update by comments --- .../embedding-helper/embedding_utils.py | 18 +++++++++--------- .../workflow/embedding-helper/main.py | 4 ++-- 2 files changed, 11 insertions(+), 11 deletions(-) diff --git a/import-automation/workflow/embedding-helper/embedding_utils.py b/import-automation/workflow/embedding-helper/embedding_utils.py index ff506bf078..83231acd42 100644 --- a/import-automation/workflow/embedding-helper/embedding_utils.py +++ b/import-automation/workflow/embedding-helper/embedding_utils.py @@ -15,8 +15,9 @@ """Helper utilities for embedding workflows.""" import logging +import time from datetime import datetime -from google.cloud.spanner_v1.param_types import TIMESTAMP, STRING, Array, Struct, Struct, StructField +from google.cloud.spanner_v1.param_types import TIMESTAMP, STRING, Array, Struct, StructField def get_latest_lock_timestamp(database): @@ -28,10 +29,10 @@ def get_latest_lock_timestamp(database): Returns: The latest AcquiredTimestamp as a datetime object, or None if no entries exist. 
""" - sql = "SELECT MAX(AcquiredTimestamp) FROM IngestionLock" + time_lock_sql = "SELECT MAX(AcquiredTimestamp) FROM IngestionLock" try: with database.snapshot() as snapshot: - results = snapshot.execute_sql(sql) + results = snapshot.execute_sql(time_lock_sql) for row in results: return row[0] except Exception as e: @@ -52,7 +53,7 @@ def get_updated_nodes(database, timestamp, node_types): """ if not timestamp: logging.info("No timestamp provided, reading all nodes.") - sql = """ + updated_node_sql = """ SELECT subject_id, name FROM Node WHERE name IS NOT NULL AND EXISTS ( @@ -63,7 +64,7 @@ def get_updated_nodes(database, timestamp, node_types): param_types = {"node_types": Array(STRING)} else: logging.info(f"Filtering nodes updated after {timestamp}") - sql = """ + updated_node_sql = """ SELECT subject_id, name FROM Node WHERE update_timestamp > @timestamp AND name IS NOT NULL @@ -77,7 +78,7 @@ def get_updated_nodes(database, timestamp, node_types): nodes = [] try: with database.snapshot() as snapshot: - results = snapshot.execute_sql(sql, params=params, param_types=param_types) + results = snapshot.execute_sql(updated_node_sql, params=params, param_types=param_types) fields = None for row in results: if fields is None: @@ -99,7 +100,6 @@ def filter_and_convert_nodes(nodes): Returns: A list of tuples (subject_id, embedding_content). 
""" - valid_tuples = [] valid_tuples = [ (node.get("subject_id"), node.get("name")) for node in nodes @@ -128,7 +128,7 @@ def generate_embeddings_partitioned(database, nodes): logging.info(f"Generating embeddings for {len(nodes)} nodes in batches of {BATCH_SIZE}.") - sql = """ + embeddings_sql = """ INSERT OR UPDATE INTO NodeEmbeddings (subject_id, embedding_content, embeddings) SELECT subject_id, content, embeddings.values FROM ML.PREDICT( @@ -149,7 +149,7 @@ def generate_embeddings_partitioned(database, nodes): param_types = {"nodes": Array(struct_type)} def _execute_dml(transaction): - return transaction.execute_update(sql, params=params, param_types=param_types) + return transaction.execute_update(embeddings_sql, params=params, param_types=param_types) try: row_count = database.run_in_transaction(_execute_dml) diff --git a/import-automation/workflow/embedding-helper/main.py b/import-automation/workflow/embedding-helper/main.py index a7f0ccf9ba..24f3f0d451 100644 --- a/import-automation/workflow/embedding-helper/main.py +++ b/import-automation/workflow/embedding-helper/main.py @@ -34,10 +34,10 @@ def embedding_helper(request): action_type = request_json['actionType'] logging.info(f"Received request for actionType: {action_type}") - if action_type == 'initilization': + if action_type == 'initialization': return jsonify({ "status": "success", - "message": "initilization action triggered" + "message": "initialization action triggered" }), 200 elif action_type == 'incremental_update': return jsonify({ From 9f23be8e150259424487324f0541fef6884556f8 Mon Sep 17 00:00:00 2001 From: shixiao-coder Date: Mon, 20 Apr 2026 10:12:20 -0400 Subject: [PATCH 3/7] Modifying the batch size to be divisible by 250. Vertex AI send request on a request of 250. If the batch is smaller than 250 and or not divisible by 250. It actually send more requests to Embeddings models, with each batch containing a much smaller number. 
Batch from 100 -> 500 changes QPM usage from 1000 to 700 Timeout is set since each request is now containing 250 data and will run longer --- .../workflow/embedding-helper/embedding_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/import-automation/workflow/embedding-helper/embedding_utils.py b/import-automation/workflow/embedding-helper/embedding_utils.py index 83231acd42..d347cfb161 100644 --- a/import-automation/workflow/embedding-helper/embedding_utils.py +++ b/import-automation/workflow/embedding-helper/embedding_utils.py @@ -123,7 +123,7 @@ def generate_embeddings_partitioned(database, nodes): logging.info("No nodes to update.") return 0 - BATCH_SIZE = 100 + BATCH_SIZE = 500 total_rows_affected = 0 logging.info(f"Generating embeddings for {len(nodes)} nodes in batches of {BATCH_SIZE}.") @@ -149,7 +149,7 @@ def generate_embeddings_partitioned(database, nodes): param_types = {"nodes": Array(struct_type)} def _execute_dml(transaction): - return transaction.execute_update(embeddings_sql, params=params, param_types=param_types) + return transaction.execute_update(embeddings_sql, params=params, param_types=param_types, timeout=300) try: row_count = database.run_in_transaction(_execute_dml) From 844b92ef5e5f93c040b383e16ca9ab9f66c1f4f9 Mon Sep 17 00:00:00 2001 From: shixiao-coder Date: Mon, 20 Apr 2026 10:27:27 -0400 Subject: [PATCH 4/7] Reine the logic to use timestamp to filter nodes --- .../embedding-helper/embedding_utils.py | 41 +++++++++---------- 1 file changed, 19 insertions(+), 22 deletions(-) diff --git a/import-automation/workflow/embedding-helper/embedding_utils.py b/import-automation/workflow/embedding-helper/embedding_utils.py index d347cfb161..acf85f29fd 100644 --- a/import-automation/workflow/embedding-helper/embedding_utils.py +++ b/import-automation/workflow/embedding-helper/embedding_utils.py @@ -51,29 +51,26 @@ def get_updated_nodes(database, timestamp, node_types): Returns: A list of dictionaries containing 
subject_id and name. """ - if not timestamp: - logging.info("No timestamp provided, reading all nodes.") - updated_node_sql = """ - SELECT subject_id, name FROM Node - WHERE name IS NOT NULL - AND EXISTS ( - SELECT 1 FROM UNNEST(types) AS t WHERE t IN UNNEST(@node_types) - ) - """ - params = {"node_types": node_types} - param_types = {"node_types": Array(STRING)} + timestamp_condition = "update_timestamp > @timestamp" if timestamp else "TRUE" + + updated_node_sql = f""" + SELECT subject_id, name FROM Node + WHERE name IS NOT NULL + AND {timestamp_condition} + AND EXISTS ( + SELECT 1 FROM UNNEST(types) AS t WHERE t IN UNNEST(@node_types) + ) + """ + + params = {"node_types": node_types} + param_types = {"node_types": Array(STRING)} + + if timestamp: + logging.info(f"Filtering valid nodes updated after {timestamp}") + params["timestamp"] = timestamp + param_types["timestamp"] = TIMESTAMP else: - logging.info(f"Filtering nodes updated after {timestamp}") - updated_node_sql = """ - SELECT subject_id, name FROM Node - WHERE update_timestamp > @timestamp - AND name IS NOT NULL - AND EXISTS ( - SELECT 1 FROM UNNEST(types) AS t WHERE t IN UNNEST(@node_types) - ) - """ - params = {"timestamp": timestamp, "node_types": node_types} - param_types = {"timestamp": TIMESTAMP, "node_types": Array(STRING)} + logging.info("No timestamp provided, reading all valid nodes.") nodes = [] try: From 470c64b563aea1e382daf41e9493ed8c76620b0f Mon Sep 17 00:00:00 2001 From: shixiao-coder Date: Mon, 20 Apr 2026 17:04:14 -0400 Subject: [PATCH 5/7] Updated to pass data by stream and related Docker to be deployed to cloud run Running with experiment deployed image and confirmed proper ingestion --- .../workflow/embedding-helper/Dockerfile | 10 +++ .../embedding-helper/embedding_utils.py | 50 +++++------ .../workflow/embedding-helper/main.py | 84 ++++++++----------- .../embedding-helper/requirements.txt | 1 + 4 files changed, 71 insertions(+), 74 deletions(-) create mode 100644 
import-automation/workflow/embedding-helper/Dockerfile diff --git a/import-automation/workflow/embedding-helper/Dockerfile b/import-automation/workflow/embedding-helper/Dockerfile new file mode 100644 index 0000000000..2f61d779ec --- /dev/null +++ b/import-automation/workflow/embedding-helper/Dockerfile @@ -0,0 +1,10 @@ +FROM python:3.10-slim + +WORKDIR /app + +COPY requirements.txt . +RUN pip install -r requirements.txt + +COPY . . + +CMD ["python", "main.py"] diff --git a/import-automation/workflow/embedding-helper/embedding_utils.py b/import-automation/workflow/embedding-helper/embedding_utils.py index acf85f29fd..fafadab242 100644 --- a/import-automation/workflow/embedding-helper/embedding_utils.py +++ b/import-automation/workflow/embedding-helper/embedding_utils.py @@ -14,6 +14,7 @@ """Helper utilities for embedding workflows.""" +import itertools import logging import time from datetime import datetime @@ -42,14 +43,15 @@ def get_latest_lock_timestamp(database): def get_updated_nodes(database, timestamp, node_types): """Gets subject_ids and names from Node table where update_timestamp > timestamp. + Yields results to avoid loading all into memory. Args: database: google.cloud.spanner.Database object. timestamp: datetime object to filter by. node_types: A list of strings representing the node types to filter by. - Returns: - A list of dictionaries containing subject_id and name. + Yields: + Dictionaries containing subject_id and name. 
""" timestamp_condition = "update_timestamp > @timestamp" if timestamp else "TRUE" @@ -72,7 +74,6 @@ def get_updated_nodes(database, timestamp, node_types): else: logging.info("No timestamp provided, reading all valid nodes.") - nodes = [] try: with database.snapshot() as snapshot: results = snapshot.execute_sql(updated_node_sql, params=params, param_types=param_types) @@ -80,50 +81,43 @@ def get_updated_nodes(database, timestamp, node_types): for row in results: if fields is None: fields = [field.name for field in results.fields] - nodes.append(dict(zip(fields, row))) + yield dict(zip(fields, row)) except Exception as e: logging.error(f"Error fetching updated nodes: {e}") raise - return nodes -def filter_and_convert_nodes(nodes): +def filter_and_convert_nodes(nodes_generator): """Filters out nodes without a name and converts dictionaries to tuples. - Reads 'name' from input and maps it to output tuple. + Reads from a generator and yields results. Args: - nodes: A list of dictionaries containing subject_id and name. + nodes_generator: A generator yielding dictionaries containing subject_id and name. - Returns: - A list of tuples (subject_id, embedding_content). + Yields: + Tuples (subject_id, embedding_content). """ - valid_tuples = [ - (node.get("subject_id"), node.get("name")) - for node in nodes - if node.get("name") - ] - return valid_tuples + for node in nodes_generator: + if node.get("name"): + yield (node.get("subject_id"), node.get("name")) -def generate_embeddings_partitioned(database, nodes): +def generate_embeddings_partitioned(database, nodes_generator): """Generates embeddings in batches using standard transactions. Processes nodes in chunks of 500 to avoid transaction size limits. + Accepts a generator to avoid loading all nodes into memory. Args: database: google.cloud.spanner.Database object. - nodes: A list of tuples containing (subject_id, embedding_content). 
+ nodes_generator: A generator yielding tuples containing (subject_id, embedding_content). Returns: The number of affected rows. """ - if not nodes: - logging.info("No nodes to update.") - return 0 - BATCH_SIZE = 500 total_rows_affected = 0 - logging.info(f"Generating embeddings for {len(nodes)} nodes in batches of {BATCH_SIZE}.") + logging.info(f"Generating embeddings in batches of {BATCH_SIZE}.") embeddings_sql = """ INSERT OR UPDATE INTO NodeEmbeddings (subject_id, embedding_content, embeddings) @@ -139,9 +133,15 @@ def generate_embeddings_partitioned(database, nodes): StructField("embedding_content", STRING) ]) - for i in range(0, len(nodes), BATCH_SIZE): - batch = nodes[i : i + BATCH_SIZE] + def chunked(iterable, n): + it = iter(iterable) + while True: + chunk = list(itertools.islice(it, n)) + if not chunk: + break + yield chunk + for batch in chunked(nodes_generator, BATCH_SIZE): params = {"nodes": batch} param_types = {"nodes": Array(struct_type)} diff --git a/import-automation/workflow/embedding-helper/main.py b/import-automation/workflow/embedding-helper/main.py index 24f3f0d451..480c449298 100644 --- a/import-automation/workflow/embedding-helper/main.py +++ b/import-automation/workflow/embedding-helper/main.py @@ -1,54 +1,40 @@ -# Copyright 2026 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. 
- -import functions_framework +import os import logging -from flask import jsonify - -logging.getLogger().setLevel(logging.INFO) +from google.cloud import spanner +from embedding_utils import get_updated_nodes, filter_and_convert_nodes, generate_embeddings_partitioned -@functions_framework.http -def embedding_helper(request): - """ - HTTP Cloud Function for handling embedding tasks. - Takes request and argument actionType. - """ - request_json = request.get_json(silent=True) - if not request_json: - return ('Request is not a valid JSON', 400) +logging.basicConfig(level=logging.INFO) - if 'actionType' not in request_json: - return ("'actionType' parameter is missing", 400) +def main(): + # Read configuration from environment variables + instance_id = os.environ.get("SPANNER_INSTANCE") + database_id = os.environ.get("SPANNER_DATABASE") + project_id = os.environ.get("SPANNER_PROJECT") + + if not instance_id or not database_id: + logging.error("SPANNER_INSTANCE or SPANNER_DATABASE environment variables not set.") + exit(1) + + logging.info(f"Connecting to Spanner instance: {instance_id}, database: {database_id}, project: {project_id}") + + spanner_client = spanner.Client(project=project_id) + instance = spanner_client.instance(instance_id) + database = instance.database(database_id) - action_type = request_json['actionType'] - logging.info(f"Received request for actionType: {action_type}") + node_types = ["StatisticalVariable", "Topic"] + + try: + logging.info(f"Job started. Fetching all nodes for types: {node_types}") + nodes = get_updated_nodes(database, None, node_types) + + converted_nodes = filter_and_convert_nodes(nodes) + + affected_rows = generate_embeddings_partitioned(database, converted_nodes) + + logging.info(f"Job completed successfully. 
Total affected rows: {affected_rows}") + except Exception as e: + logging.error(f"Job failed with error: {e}") + exit(1) - if action_type == 'initialization': - return jsonify({ - "status": "success", - "message": "initialization action triggered" - }), 200 - elif action_type == 'incremental_update': - return jsonify({ - "status": "success", - "message": "incremental_update action triggered" - }), 200 - elif action_type == 'manual_update': - return jsonify({ - "status": "success", - "message": "manual_update action triggered" - }), 200 - else: - logging.warning(f"Unknown actionType: {action_type}") - return (f"Unknown actionType: {action_type}", 400) +if __name__ == "__main__": + main() diff --git a/import-automation/workflow/embedding-helper/requirements.txt b/import-automation/workflow/embedding-helper/requirements.txt index 754e2bd573..58cfe74bd9 100644 --- a/import-automation/workflow/embedding-helper/requirements.txt +++ b/import-automation/workflow/embedding-helper/requirements.txt @@ -1,2 +1,3 @@ functions-framework==3.* google-cloud-spanner +google-auth From 2a4a745fee6b848cb6e4e0e1095b73c5f247452e Mon Sep 17 00:00:00 2001 From: shixiao-coder Date: Tue, 21 Apr 2026 16:04:44 -0400 Subject: [PATCH 6/7] Update the NodeEmbeddings table to contain the types. 
Types will be used for filtering Nodes --- .../embedding-helper/embedding_utils.py | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/import-automation/workflow/embedding-helper/embedding_utils.py b/import-automation/workflow/embedding-helper/embedding_utils.py index fafadab242..59cdfb0e4c 100644 --- a/import-automation/workflow/embedding-helper/embedding_utils.py +++ b/import-automation/workflow/embedding-helper/embedding_utils.py @@ -56,7 +56,7 @@ def get_updated_nodes(database, timestamp, node_types): timestamp_condition = "update_timestamp > @timestamp" if timestamp else "TRUE" updated_node_sql = f""" - SELECT subject_id, name FROM Node + SELECT subject_id, name, types FROM Node WHERE name IS NOT NULL AND {timestamp_condition} AND EXISTS ( @@ -92,14 +92,14 @@ def filter_and_convert_nodes(nodes_generator): Reads from a generator and yields results. Args: - nodes_generator: A generator yielding dictionaries containing subject_id and name. + nodes_generator: A generator yielding dictionaries containing subject_id, name, and types. Yields: - Tuples (subject_id, embedding_content). + Tuples (subject_id, embedding_content, types). 
""" for node in nodes_generator: if node.get("name"): - yield (node.get("subject_id"), node.get("name")) + yield (node.get("subject_id"), node.get("name"), node.get("types")) def generate_embeddings_partitioned(database, nodes_generator): @@ -120,17 +120,18 @@ def generate_embeddings_partitioned(database, nodes_generator): logging.info(f"Generating embeddings in batches of {BATCH_SIZE}.") embeddings_sql = """ - INSERT OR UPDATE INTO NodeEmbeddings (subject_id, embedding_content, embeddings) - SELECT subject_id, content, embeddings.values + INSERT OR UPDATE INTO NodeEmbeddings (subject_id, embedding_content, embeddings, types) + SELECT subject_id, content, embeddings.values, types FROM ML.PREDICT( MODEL text_embeddings, - (SELECT subject_id, embedding_content AS content, "RETRIEVAL_QUERY" AS task_type FROM UNNEST(@nodes)) + (SELECT subject_id, embedding_content AS content, types, "RETRIEVAL_QUERY" AS task_type FROM UNNEST(@nodes)) ) """ struct_type = Struct([ StructField("subject_id", STRING), - StructField("embedding_content", STRING) + StructField("embedding_content", STRING), + StructField("types", Array(STRING)) ]) def chunked(iterable, n): From 11465f25d1102d4b6e9237b977250f28e0b046c9 Mon Sep 17 00:00:00 2001 From: shixiao-coder Date: Wed, 22 Apr 2026 10:54:14 -0400 Subject: [PATCH 7/7] Add tests for all embedding util functions as well E2E for main. 
Updating header message for code --- .../embedding-helper/embedding_utils.py | 8 +- .../workflow/embedding-helper/main.py | 19 ++- .../test/embedding_utils_test.py | 111 ++++++++++++++++++ .../embedding-helper/test/main_test.py | 61 ++++++++++ 4 files changed, 194 insertions(+), 5 deletions(-) create mode 100644 import-automation/workflow/embedding-helper/test/embedding_utils_test.py create mode 100644 import-automation/workflow/embedding-helper/test/main_test.py diff --git a/import-automation/workflow/embedding-helper/embedding_utils.py b/import-automation/workflow/embedding-helper/embedding_utils.py index 59cdfb0e4c..ce7b4762c5 100644 --- a/import-automation/workflow/embedding-helper/embedding_utils.py +++ b/import-automation/workflow/embedding-helper/embedding_utils.py @@ -21,6 +21,8 @@ from google.cloud.spanner_v1.param_types import TIMESTAMP, STRING, Array, Struct, StructField +_BATCH_SIZE = 500 + def get_latest_lock_timestamp(database): """Gets the latest AcquiredTimestamp from IngestionLock table. @@ -114,10 +116,10 @@ def generate_embeddings_partitioned(database, nodes_generator): Returns: The number of affected rows. 
""" - BATCH_SIZE = 500 + global _BATCH_SIZE total_rows_affected = 0 - logging.info(f"Generating embeddings in batches of {BATCH_SIZE}.") + logging.info(f"Generating embeddings in batches of {_BATCH_SIZE}.") embeddings_sql = """ INSERT OR UPDATE INTO NodeEmbeddings (subject_id, embedding_content, embeddings, types) @@ -142,7 +144,7 @@ def chunked(iterable, n): break yield chunk - for batch in chunked(nodes_generator, BATCH_SIZE): + for batch in chunked(nodes_generator, _BATCH_SIZE): params = {"nodes": batch} param_types = {"nodes": Array(struct_type)} diff --git a/import-automation/workflow/embedding-helper/main.py b/import-automation/workflow/embedding-helper/main.py index 480c449298..a0247027b8 100644 --- a/import-automation/workflow/embedding-helper/main.py +++ b/import-automation/workflow/embedding-helper/main.py @@ -1,7 +1,21 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os import logging from google.cloud import spanner -from embedding_utils import get_updated_nodes, filter_and_convert_nodes, generate_embeddings_partitioned +from embedding_utils import get_latest_lock_timestamp, get_updated_nodes, filter_and_convert_nodes, generate_embeddings_partitioned logging.basicConfig(level=logging.INFO) @@ -25,7 +39,8 @@ def main(): try: logging.info(f"Job started. 
Fetching all nodes for types: {node_types}") - nodes = get_updated_nodes(database, None, node_types) + timestamp = get_latest_lock_timestamp(database) + nodes = get_updated_nodes(database, timestamp, node_types) converted_nodes = filter_and_convert_nodes(nodes) diff --git a/import-automation/workflow/embedding-helper/test/embedding_utils_test.py b/import-automation/workflow/embedding-helper/test/embedding_utils_test.py new file mode 100644 index 0000000000..9e261ea627 --- /dev/null +++ b/import-automation/workflow/embedding-helper/test/embedding_utils_test.py @@ -0,0 +1,111 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
import unittest
from unittest.mock import MagicMock, patch
from datetime import datetime
import sys
import os

# Make embedding_utils (one directory up from this test/ directory) importable.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from embedding_utils import (
    get_latest_lock_timestamp,
    get_updated_nodes,
    filter_and_convert_nodes,
    generate_embeddings_partitioned
)


class TestEmbeddingUtils(unittest.TestCase):
    """Unit tests for embedding_utils with Spanner fully mocked."""

    @staticmethod
    def _database_with_snapshot():
        """Returns (database, snapshot) mocks where database.snapshot()
        behaves as a context manager yielding the snapshot mock."""
        database = MagicMock()
        snapshot = MagicMock()
        database.snapshot.return_value.__enter__.return_value = snapshot
        return database, snapshot

    def test_get_latest_lock_timestamp(self):
        database, snapshot = self._database_with_snapshot()
        expected_timestamp = datetime(2026, 4, 20, 12, 0, 0)
        snapshot.execute_sql.return_value = [(expected_timestamp,)]

        self.assertEqual(get_latest_lock_timestamp(database),
                         expected_timestamp)

    def test_get_latest_lock_timestamp_no_rows(self):
        # Edge case: an empty result set must yield None, not raise.
        database, snapshot = self._database_with_snapshot()
        snapshot.execute_sql.return_value = []

        self.assertIsNone(get_latest_lock_timestamp(database))

    def test_get_updated_nodes(self):
        database, snapshot = self._database_with_snapshot()

        class MockField:
            """Mimics a Spanner result-set field descriptor (has .name)."""

            def __init__(self, name):
                self.name = name

        class MockResults:
            """Mimics a Spanner StreamedResultSet: iterable rows + .fields."""

            def __init__(self, rows, field_names):
                self.rows = rows
                self.fields = [MockField(name) for name in field_names]

            def __iter__(self):
                return iter(self.rows)

        snapshot.execute_sql.return_value = MockResults(
            rows=[("dc/1", "Node 1", ["Topic"])],
            field_names=["subject_id", "name", "types"]
        )

        nodes = list(get_updated_nodes(database, None, ["Topic"]))
        self.assertEqual(len(nodes), 1)
        self.assertEqual(nodes[0]["subject_id"], "dc/1")
        self.assertEqual(nodes[0]["name"], "Node 1")
        self.assertEqual(nodes[0]["types"], ["Topic"])

    def test_filter_and_convert_nodes(self):
        # Nodes with a missing (None) or empty name are dropped; the rest
        # are converted to (subject_id, name, types) tuples, order preserved.
        nodes = [
            {"subject_id": "dc/1", "name": "Node 1", "types": ["Topic"]},
            {"subject_id": "dc/2", "name": None, "types": ["StatisticalVariable"]},
            {"subject_id": "dc/3", "name": "Node 3", "types": ["Topic", "StatisticalVariable"]},
            {"subject_id": "dc/4", "name": "", "types": ["StatisticalVariable"]}
        ]

        converted = list(filter_and_convert_nodes(nodes))
        self.assertEqual(len(converted), 2)
        self.assertEqual(converted[0], ("dc/1", "Node 1", ["Topic"]))
        self.assertEqual(converted[1],
                         ("dc/3", "Node 3", ["Topic", "StatisticalVariable"]))

    def test_filter_and_convert_nodes_empty(self):
        # Edge case: no input nodes -> no output.
        self.assertEqual(list(filter_and_convert_nodes([])), [])

    @patch('embedding_utils._BATCH_SIZE', 2)
    def test_generate_embeddings_partitioned(self):
        # With the batch size patched to 2, 8 nodes must be written in 4
        # transactions; each mocked transaction reports 2 affected rows,
        # so the total affected-row count is 8.
        database = MagicMock()

        nodes = [(f"dc/{i}", f"Node {i}", ["Topic"]) for i in range(1, 9)]

        def run_transaction(func):
            transaction = MagicMock()
            transaction.execute_update.return_value = 2
            return func(transaction)

        database.run_in_transaction.side_effect = run_transaction

        affected_rows = generate_embeddings_partitioned(database, nodes)
        self.assertEqual(affected_rows, 8)
        self.assertEqual(database.run_in_transaction.call_count, 4)


if __name__ == '__main__':
    unittest.main()
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest
from unittest.mock import MagicMock, patch
import sys
import os
from datetime import datetime

# Make main.py (one directory up from this test/ directory) importable.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import main


class TestMain(unittest.TestCase):
    """End-to-end test of main.main() with every collaborator mocked."""

    @patch.dict(os.environ, {
        "SPANNER_INSTANCE": "test-instance",
        "SPANNER_DATABASE": "test-db",
        "SPANNER_PROJECT": "test-proj"
    })
    @patch('main.spanner.Client')
    @patch('main.get_latest_lock_timestamp')
    @patch('main.get_updated_nodes')
    @patch('main.filter_and_convert_nodes')
    @patch('main.generate_embeddings_partitioned')
    def test_main_e2e_success(self, mock_generate, mock_filter, mock_nodes,
                              mock_timestamp, mock_spanner):
        # Wire up the Spanner client chain: Client() -> instance -> database.
        fake_database = MagicMock()
        fake_instance = MagicMock()
        fake_instance.database.return_value = fake_database
        mock_spanner.return_value.instance.return_value = fake_instance

        # Canned data flowing through each mocked pipeline stage.
        lock_time = datetime(2026, 4, 20, 12, 0, 0)
        mock_timestamp.return_value = lock_time
        mock_nodes.return_value = [
            {"subject_id": "dc/1", "name": "Node 1", "types": ["Topic"]}
        ]
        mock_filter.return_value = [("dc/1", "Node 1", ["Topic"])]
        mock_generate.return_value = 1

        # main() may terminate via sys.exit(); only exit code 0 is a pass.
        try:
            main.main()
        except SystemExit as e:
            self.assertEqual(e.code, 0)

        mock_spanner.assert_called_once_with(project="test-proj")
        mock_timestamp.assert_called_once_with(fake_database)
        mock_nodes.assert_called_once_with(
            fake_database, lock_time, ["StatisticalVariable", "Topic"])
        mock_filter.assert_called_once()
        mock_generate.assert_called_once()


if __name__ == '__main__':
    unittest.main()