Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 26 additions & 36 deletions src/memos/graph_dbs/polardb.py
Original file line number Diff line number Diff line change
Expand Up @@ -1800,7 +1800,6 @@ def search_by_fulltext(
FROM "{self.db_name}_graph"."Memory" m
CROSS JOIN q
{where_clause_cte}
ORDER BY rank DESC
LIMIT {top_k};
"""
params = [tsquery_string]
Expand Down Expand Up @@ -2411,56 +2410,47 @@ def _extract_special_filter_values(filter_obj):
order_clause = """
ORDER BY ag_catalog.agtype_access_operator(properties, '"created_at"'::agtype) DESC NULLS LAST,id DESC
"""
count_query = f"""
SELECT COUNT(*) AS total_count
FROM "{self.db_name}_graph"."Memory"
{where_clause}
"""
if include_embedding:
node_query = f"""
WITH filtered AS (
SELECT id, properties, embedding
FROM "{self.db_name}_graph"."Memory"
{where_clause}
)
SELECT p.id, p.properties, p.embedding, c.total_count
FROM (SELECT COUNT(*) AS total_count FROM filtered) c
LEFT JOIN LATERAL (
SELECT id, properties, embedding
FROM filtered
{order_clause}
{pagination_clause}
) p ON TRUE
data_query = f"""
SELECT id, properties, embedding
FROM "{self.db_name}_graph"."Memory"
{where_clause}
{order_clause}
{pagination_clause}
"""
else:
node_query = f"""
WITH filtered AS (
SELECT id, properties
FROM "{self.db_name}_graph"."Memory"
{where_clause}
)
SELECT p.id, p.properties, c.total_count
FROM (SELECT COUNT(*) AS total_count FROM filtered) c
LEFT JOIN LATERAL (
SELECT id, properties
FROM filtered
{order_clause}
{pagination_clause}
) p ON TRUE
data_query = f"""
SELECT id, properties
FROM "{self.db_name}_graph"."Memory"
{where_clause}
{order_clause}
{pagination_clause}
"""
logger.info(f"[export_graph nodes] Query: {node_query}")
logger.info(f"[export_graph nodes] count_query: {count_query}")
logger.info(f"[export_graph nodes] data_query: {data_query}")

try:
with self._get_connection() as conn, conn.cursor() as cursor:
cursor.execute(node_query)
cursor.execute(count_query)
count_row = cursor.fetchone()
total_nodes = int(count_row[0]) if count_row and count_row[0] is not None else 0

cursor.execute(data_query)
node_results = cursor.fetchall()
nodes = []

for row in node_results:
if include_embedding:
row_id, properties_json, embedding_json, row_total_count = row
row_id, properties_json, embedding_json = row
else:
row_id, properties_json, row_total_count = row
row_id, properties_json = row
embedding_json = None

if row_total_count is not None:
total_nodes = int(row_total_count)

if row_id is None:
continue

Expand Down
Loading