From 95b4b92685702114c1f52581e02f1819050b5ceb Mon Sep 17 00:00:00 2001 From: Honglei-Qiu <1044497581@qq.com> Date: Mon, 23 Mar 2026 08:46:54 +0000 Subject: [PATCH 01/10] group by v2 --- sqlite/graph_net_sample_groups_insert2.py | 227 ++++++++++++++++++++++ 1 file changed, 227 insertions(+) create mode 100755 sqlite/graph_net_sample_groups_insert2.py diff --git a/sqlite/graph_net_sample_groups_insert2.py b/sqlite/graph_net_sample_groups_insert2.py new file mode 100755 index 000000000..d8c767f66 --- /dev/null +++ b/sqlite/graph_net_sample_groups_insert2.py @@ -0,0 +1,227 @@ +""" +bucket_policy_v2: Generate graph_net_sample_groups for v2 candidates. + +v2 candidates = total_graph_buckets - v1_selected_graph_buckets + +Graph bucket key: (op_seq_bucket_id, input_shapes_bucket_id, input_dtypes_bucket_id, graph_hash) + +v2 grouping strategy (progressive, mutually exclusive with v1): + Rule 3 (global sparse sampling): per op_seq, sort by sample_uid, every 5 pick 1. + Rule 4 (dtype coverage): from Rule 3 remainder, per op_seq+shape, pick num_dtypes different-dtype samples. +""" + +import argparse +import sqlite3 +import uuid as uuid_module +from datetime import datetime +from collections import namedtuple, defaultdict + +from orm_models import ( + get_session, + GraphNetSampleGroup, +) + + +class DB: + def __init__(self, path): + self.path = path + + def connect(self): + self.conn = sqlite3.connect(self.path) + self.conn.row_factory = sqlite3.Row + self.cur = self.conn.cursor() + + def query(self, sql, params=None): + self.cur.execute(sql, params or ()) + return self.cur.fetchall() + + def close(self): + self.conn.close() + + +CandidateGraph = namedtuple( + "CandidateGraph", + [ + "sample_uid", + "op_seq_bucket_id", + "input_shapes_bucket_id", + "input_dtypes_bucket_id", + "graph_hash", + ], +) + + +def query_v2_candidates(db: DB) -> list[CandidateGraph]: + """ + Query v2 candidate graphs: + total graph buckets - v1 selected graph buckets. + + Each (op_seq, shapes, dtypes, graph_hash) bucket picks one representative sample_uid + (the earliest by create_at, uuid). + + v1 selected = sample_uids already in graph_net_sample_groups with group_policy='bucket_policy_v1'. + """ + query_str = """ +SELECT + sub.sample_uid, + sub.op_seq_bucket_id, + sub.input_shapes_bucket_id, + sub.input_dtypes_bucket_id, + sub.graph_hash +FROM ( + SELECT + s.uuid AS sample_uid, + b.op_seq_bucket_id, + b.input_shapes_bucket_id, + b.input_dtypes_bucket_id, + s.graph_hash, + ROW_NUMBER() OVER ( + PARTITION BY b.op_seq_bucket_id, b.input_shapes_bucket_id, b.input_dtypes_bucket_id, s.graph_hash + ORDER BY s.create_at ASC, s.uuid ASC + ) AS rn + FROM graph_sample s + JOIN graph_net_sample_buckets b ON s.uuid = b.sample_uid + WHERE s.deleted = 0 + AND s.sample_type != 'full_graph' +) sub +WHERE sub.rn = 1 + AND sub.sample_uid NOT IN ( + SELECT g.sample_uid + FROM graph_net_sample_groups g + WHERE g.group_policy = 'bucket_policy_v1' + AND g.deleted = 0 + ) +ORDER BY sub.op_seq_bucket_id, sub.input_shapes_bucket_id, sub.input_dtypes_bucket_id, sub.sample_uid; + """ + rows = db.query(query_str) + return [CandidateGraph(*row) for row in rows] + + +def get_v2_group_members(candidates: list[CandidateGraph], num_dtypes: int): + """ + Yield (sample_uid, group_uid) pairs for v2 grouping. + + Rule 3 (global sparse sampling): + Sort all candidates by sample_uid, every 5 pick 1 sample. + All picked samples under the same op_seq share one group_uid. + + Rule 4 (dtype coverage): + From remaining candidates (not selected in Rule 3), + per op_seq, for each shape, pick up to num_dtypes samples with different dtypes. + All picked samples under the same op_seq share one group_uid. + """ + # Index candidates by op_seq + by_op_seq = defaultdict(list) + for c in candidates: + by_op_seq[c.op_seq_bucket_id].append(c) + + rule3_selected_uids = set() + + # --- Rule 3: global sparse sampling --- + # Window size = num_dtypes * stride(5) = 15, pick first num_dtypes(3) per window + window_size = num_dtypes * 5 + for op_seq, op_candidates in by_op_seq.items(): + sorted_candidates = sorted(op_candidates, key=lambda c: c.sample_uid) + + rule3_uids = [] + for order_value, c in enumerate(sorted_candidates): + if (order_value % window_size) < num_dtypes: + rule3_uids.append(c.sample_uid) + rule3_selected_uids.add(c.sample_uid) + + if rule3_uids: + group_uid = str(uuid_module.uuid4()) + for uid in rule3_uids: + yield uid, group_uid + + # --- Rule 4: dtype coverage --- + for op_seq, op_candidates in by_op_seq.items(): + remaining = [ + c for c in op_candidates if c.sample_uid not in rule3_selected_uids + ] + + # Sub-group by shape + by_shape = defaultdict(list) + for c in remaining: + by_shape[c.input_shapes_bucket_id].append(c) + + rule4_uids = [] + for shape, shape_candidates in by_shape.items(): + # Pick up to num_dtypes samples with different dtypes + seen_dtypes = set() + for c in shape_candidates: + if ( + c.input_dtypes_bucket_id not in seen_dtypes + and len(seen_dtypes) < num_dtypes + ): + seen_dtypes.add(c.input_dtypes_bucket_id) + rule4_uids.append(c.sample_uid) + + if rule4_uids: + group_uid = str(uuid_module.uuid4()) + for uid in rule4_uids: + yield uid, group_uid + + +def main(): + parser = argparse.ArgumentParser( + description="Generate graph_net_sample_groups with bucket_policy_v2" + ) + parser.add_argument( + "--db_path", + type=str, + required=True, + help="Path to the SQLite database file", + ) + parser.add_argument( + "--num_dtypes", + type=int, + default=3, + help="Number of different dtypes to pick per shape (default: 3)", + ) + + args = parser.parse_args() + db = DB(args.db_path) + db.connect() + + print("Step 1: Querying v2 candidates...") + candidates = query_v2_candidates(db) + print(f" v2 candidate graphs: {len(candidates)}") + db.close() + + if not candidates: + print("No v2 candidates found. Done!") + return + + print(f"Step 2: Generating v2 groups (num_dtypes={args.num_dtypes})...") + session = get_session(args.db_path) + + try: + count = 0 + for sample_uid, group_uid in get_v2_group_members(candidates, args.num_dtypes): + new_group = GraphNetSampleGroup( + sample_uid=sample_uid, + group_uid=group_uid, + group_type="ai4c", + group_policy="bucket_policy_v2", + policy_version="1.0", + create_at=datetime.now(), + deleted=False, + ) + session.add(new_group) + count += 1 + + session.commit() + print(f" Inserted {count} group records.") + + except Exception: + session.rollback() + raise + finally: + session.close() + + print("Done!") + + +if __name__ == "__main__": + main() From 6185419800dc78a4f63933ae469aeb03c24a10e4 Mon Sep 17 00:00:00 2001 From: Honglei-Qiu <1044497581@qq.com> Date: Mon, 23 Mar 2026 08:52:10 +0000 Subject: [PATCH 02/10] group by v2 --- sqlite/graph_net_sample_groups_insert2.py | 33 ----------------------- 1 file changed, 33 deletions(-) diff --git a/sqlite/graph_net_sample_groups_insert2.py b/sqlite/graph_net_sample_groups_insert2.py index d8c767f66..bb2ad04bd 100755 --- a/sqlite/graph_net_sample_groups_insert2.py +++ b/sqlite/graph_net_sample_groups_insert2.py @@ -1,15 +1,3 @@ -""" -bucket_policy_v2: Generate graph_net_sample_groups for v2 candidates. - -v2 candidates = total_graph_buckets - v1_selected_graph_buckets - -Graph bucket key: (op_seq_bucket_id, input_shapes_bucket_id, input_dtypes_bucket_id, graph_hash) - -v2 grouping strategy (progressive, mutually exclusive with v1): - Rule 3 (global sparse sampling): per op_seq, sort by sample_uid, every 5 pick 1. - Rule 4 (dtype coverage): from Rule 3 remainder, per op_seq+shape, pick num_dtypes different-dtype samples. -""" - import argparse import sqlite3 import uuid as uuid_module @@ -52,15 +40,6 @@ def close(self): def query_v2_candidates(db: DB) -> list[CandidateGraph]: - """ - Query v2 candidate graphs: - total graph buckets - v1 selected graph buckets. - - Each (op_seq, shapes, dtypes, graph_hash) bucket picks one representative sample_uid - (the earliest by create_at, uuid). - - v1 selected = sample_uids already in graph_net_sample_groups with group_policy='bucket_policy_v1'. - """ query_str = """ SELECT sub.sample_uid, @@ -98,18 +77,6 @@ def query_v2_candidates(db: DB) -> list[CandidateGraph]: def get_v2_group_members(candidates: list[CandidateGraph], num_dtypes: int): - """ - Yield (sample_uid, group_uid) pairs for v2 grouping. - - Rule 3 (global sparse sampling): - Sort all candidates by sample_uid, every 5 pick 1 sample. - All picked samples under the same op_seq share one group_uid. - - Rule 4 (dtype coverage): - From remaining candidates (not selected in Rule 3), - per op_seq, for each shape, pick up to num_dtypes samples with different dtypes. - All picked samples under the same op_seq share one group_uid. - """ # Index candidates by op_seq by_op_seq = defaultdict(list) for c in candidates: From 91a73899e982b88a690900d29c459b83070574ac Mon Sep 17 00:00:00 2001 From: Honglei-Qiu <1044497581@qq.com> Date: Mon, 23 Mar 2026 13:26:50 +0000 Subject: [PATCH 03/10] group by v2 --- sqlite/graph_net_sample_groups_insert2.py | 194 ---------------------- 1 file changed, 194 deletions(-) delete mode 100755 sqlite/graph_net_sample_groups_insert2.py diff --git a/sqlite/graph_net_sample_groups_insert2.py b/sqlite/graph_net_sample_groups_insert2.py deleted file mode 100755 index bb2ad04bd..000000000 --- a/sqlite/graph_net_sample_groups_insert2.py +++ /dev/null @@ -1,194 +0,0 @@ -import argparse -import sqlite3 -import uuid as uuid_module -from datetime import datetime -from collections import namedtuple, defaultdict - -from orm_models import ( - get_session, - GraphNetSampleGroup, -) - - -class DB: - def __init__(self, path): - self.path = path - - def connect(self): - self.conn = sqlite3.connect(self.path) - self.conn.row_factory = sqlite3.Row - self.cur = self.conn.cursor() - - def query(self, sql, params=None): - self.cur.execute(sql, params or ()) - return self.cur.fetchall() - - def close(self): - self.conn.close() - - -CandidateGraph = namedtuple( - "CandidateGraph", - [ - "sample_uid", - "op_seq_bucket_id", - "input_shapes_bucket_id", - "input_dtypes_bucket_id", - "graph_hash", - ], -) - - -def query_v2_candidates(db: DB) -> list[CandidateGraph]: - query_str = """ -SELECT - sub.sample_uid, - sub.op_seq_bucket_id, - sub.input_shapes_bucket_id, - sub.input_dtypes_bucket_id, - sub.graph_hash -FROM ( - SELECT - s.uuid AS sample_uid, - b.op_seq_bucket_id, - b.input_shapes_bucket_id, - b.input_dtypes_bucket_id, - s.graph_hash, - ROW_NUMBER() OVER ( - PARTITION BY b.op_seq_bucket_id, b.input_shapes_bucket_id, b.input_dtypes_bucket_id, s.graph_hash - ORDER BY s.create_at ASC, s.uuid ASC - ) AS rn - FROM graph_sample s - JOIN graph_net_sample_buckets b ON s.uuid = b.sample_uid - WHERE s.deleted = 0 - AND s.sample_type != 'full_graph' -) sub -WHERE sub.rn = 1 - AND sub.sample_uid NOT IN ( - SELECT g.sample_uid - FROM graph_net_sample_groups g - WHERE g.group_policy = 'bucket_policy_v1' - AND g.deleted = 0 - ) -ORDER BY sub.op_seq_bucket_id, sub.input_shapes_bucket_id, sub.input_dtypes_bucket_id, sub.sample_uid; - """ - rows = db.query(query_str) - return [CandidateGraph(*row) for row in rows] - - -def get_v2_group_members(candidates: list[CandidateGraph], num_dtypes: int): - # Index candidates by op_seq - by_op_seq = defaultdict(list) - for c in candidates: - by_op_seq[c.op_seq_bucket_id].append(c) - - rule3_selected_uids = set() - - # --- Rule 3: global sparse sampling --- - # Window size = num_dtypes * stride(5) = 15, pick first num_dtypes(3) per window - window_size = num_dtypes * 5 - for op_seq, op_candidates in by_op_seq.items(): - sorted_candidates = sorted(op_candidates, key=lambda c: c.sample_uid) - - rule3_uids = [] - for order_value, c in enumerate(sorted_candidates): - if (order_value % window_size) < num_dtypes: - rule3_uids.append(c.sample_uid) - rule3_selected_uids.add(c.sample_uid) - - if rule3_uids: - group_uid = str(uuid_module.uuid4()) - for uid in rule3_uids: - yield uid, group_uid - - # --- Rule 4: dtype coverage --- - for op_seq, op_candidates in by_op_seq.items(): - remaining = [ - c for c in op_candidates if c.sample_uid not in rule3_selected_uids - ] - - # Sub-group by shape - by_shape = defaultdict(list) - for c in remaining: - by_shape[c.input_shapes_bucket_id].append(c) - - rule4_uids = [] - for shape, shape_candidates in by_shape.items(): - # Pick up to num_dtypes samples with different dtypes - seen_dtypes = set() - for c in shape_candidates: - if ( - c.input_dtypes_bucket_id not in seen_dtypes - and len(seen_dtypes) < num_dtypes - ): - seen_dtypes.add(c.input_dtypes_bucket_id) - rule4_uids.append(c.sample_uid) - - if rule4_uids: - group_uid = str(uuid_module.uuid4()) - for uid in rule4_uids: - yield uid, group_uid - - -def main(): - parser = argparse.ArgumentParser( - description="Generate graph_net_sample_groups with bucket_policy_v2" - ) - parser.add_argument( - "--db_path", - type=str, - required=True, - help="Path to the SQLite database file", - ) - parser.add_argument( - "--num_dtypes", - type=int, - default=3, - help="Number of different dtypes to pick per shape (default: 3)", - ) - - args = parser.parse_args() - db = DB(args.db_path) - db.connect() - - print("Step 1: Querying v2 candidates...") - candidates = query_v2_candidates(db) - print(f" v2 candidate graphs: {len(candidates)}") - db.close() - - if not candidates: - print("No v2 candidates found. Done!") - return - - print(f"Step 2: Generating v2 groups (num_dtypes={args.num_dtypes})...") - session = get_session(args.db_path) - - try: - count = 0 - for sample_uid, group_uid in get_v2_group_members(candidates, args.num_dtypes): - new_group = GraphNetSampleGroup( - sample_uid=sample_uid, - group_uid=group_uid, - group_type="ai4c", - group_policy="bucket_policy_v2", - policy_version="1.0", - create_at=datetime.now(), - deleted=False, - ) - session.add(new_group) - count += 1 - - session.commit() - print(f" Inserted {count} group records.") - - except Exception: - session.rollback() - raise - finally: - session.close() - - print("Done!") - - -if __name__ == "__main__": - main() From e444c2e800049013b7d6d3d5d822b3dd0f274663 Mon Sep 17 00:00:00 2001 From: Honglei-Qiu <1044497581@qq.com> Date: Mon, 23 Mar 2026 13:32:07 +0000 Subject: [PATCH 04/10] group by v2 --- sqlite/graph_net_sample_groups_insert_v2.py | 238 ++++++++++++++++++++ 1 file changed, 238 insertions(+) create mode 100755 sqlite/graph_net_sample_groups_insert_v2.py diff --git a/sqlite/graph_net_sample_groups_insert_v2.py b/sqlite/graph_net_sample_groups_insert_v2.py new file mode 100755 index 000000000..9091ff015 --- /dev/null +++ b/sqlite/graph_net_sample_groups_insert_v2.py @@ -0,0 +1,238 @@ +import argparse +import sqlite3 +import uuid as uuid_module +from datetime import datetime +from collections import namedtuple, defaultdict + +from orm_models import ( + get_session, + GraphNetSampleGroup, +) + + +class DB: + def __init__(self, path): + self.path = path + + def connect(self): + self.conn = sqlite3.connect(self.path) + self.conn.row_factory = sqlite3.Row + self.cur = self.conn.cursor() + + def query(self, sql, params=None): + self.cur.execute(sql, params or ()) + return self.cur.fetchall() + + def close(self): + self.conn.close() + + +CandidateGraph = namedtuple( + "CandidateGraph", + [ + "sample_uid", + "op_seq_bucket_id", + "input_shapes_bucket_id", + "input_dtypes_bucket_id", + "graph_hash", + ], +) + + +def load_v1_paths( + v1_list_files: list[str], sample_types: list[str] +) -> set[tuple[str, str]]: + """ + Load (sample_type, relative_model_path) pairs from v1 list files. + File format: uuid\\trelative_path (tab separated, one per line) + + Each file corresponds to a sample_type, passed via sample_types in the same order. + """ + pairs = set() + for filepath, sample_type in zip(v1_list_files, sample_types): + with open(filepath, "r") as f: + for line in f: + line = line.strip() + if not line: + continue + parts = line.split("\t", 1) + if len(parts) == 2: + pairs.add((sample_type, parts[1])) + return pairs + + +def query_v2_candidates(db: DB, v1_pairs: set[tuple[str, str]]) -> list[CandidateGraph]: + query_str = """ +SELECT + sub.sample_uid, + sub.op_seq_bucket_id, + sub.input_shapes_bucket_id, + sub.input_dtypes_bucket_id, + sub.graph_hash, + sub.relative_model_path, + sub.sample_type +FROM ( + SELECT + s.uuid AS sample_uid, + s.relative_model_path, + s.sample_type, + b.op_seq_bucket_id, + b.input_shapes_bucket_id, + b.input_dtypes_bucket_id, + s.graph_hash, + ROW_NUMBER() OVER ( + PARTITION BY b.op_seq_bucket_id, b.input_shapes_bucket_id, b.input_dtypes_bucket_id, s.graph_hash + ORDER BY s.create_at ASC, s.uuid ASC + ) AS rn + FROM graph_sample s + JOIN graph_net_sample_buckets b ON s.uuid = b.sample_uid + WHERE s.deleted = 0 + AND s.sample_type != 'full_graph' +) sub +WHERE sub.rn = 1 +ORDER BY sub.op_seq_bucket_id, sub.input_shapes_bucket_id, sub.input_dtypes_bucket_id, sub.sample_uid; + """ + rows = db.query(query_str) + # Filter out v1 selected (sample_type, path) pairs in Python + candidates = [] + for row in rows: + sample_uid, op_seq, shapes, dtypes, graph_hash, rel_path, sample_type = row + if (sample_type, rel_path) not in v1_pairs: + candidates.append( + CandidateGraph(sample_uid, op_seq, shapes, dtypes, graph_hash) + ) + return candidates + + +def get_v2_group_members(candidates: list[CandidateGraph], num_dtypes: int): + # Index candidates by op_seq + by_op_seq = defaultdict(list) + for c in candidates: + by_op_seq[c.op_seq_bucket_id].append(c) + + rule3_selected_uids = set() + + # --- Rule 3: global sparse sampling --- + # Window size = num_dtypes * stride(5) = 15, pick first num_dtypes(3) per window + window_size = num_dtypes * 5 + for op_seq, op_candidates in by_op_seq.items(): + sorted_candidates = sorted(op_candidates, key=lambda c: c.sample_uid) + + rule3_uids = [] + for order_value, c in enumerate(sorted_candidates): + if (order_value % window_size) < num_dtypes: + rule3_uids.append(c.sample_uid) + rule3_selected_uids.add(c.sample_uid) + + if rule3_uids: + group_uid = str(uuid_module.uuid4()) + for uid in rule3_uids: + yield uid, group_uid + + # --- Rule 4: dtype coverage --- + for op_seq, op_candidates in by_op_seq.items(): + remaining = [ + c for c in op_candidates if c.sample_uid not in rule3_selected_uids + ] + + # Sub-group by shape + by_shape = defaultdict(list) + for c in remaining: + by_shape[c.input_shapes_bucket_id].append(c) + + rule4_uids = [] + for shape, shape_candidates in by_shape.items(): + # Pick up to num_dtypes samples with different dtypes + seen_dtypes = set() + for c in shape_candidates: + if ( + c.input_dtypes_bucket_id not in seen_dtypes + and len(seen_dtypes) < num_dtypes + ): + seen_dtypes.add(c.input_dtypes_bucket_id) + rule4_uids.append(c.sample_uid) + + if rule4_uids: + group_uid = str(uuid_module.uuid4()) + for uid in rule4_uids: + yield uid, group_uid + + +def main(): + parser = argparse.ArgumentParser( + description="Generate graph_net_sample_groups with bucket_policy_v2" + ) + parser.add_argument( + "--db_path", + type=str, + required=True, + help="Path to the SQLite database file", + ) + parser.add_argument( + "--num_dtypes", + type=int, + default=3, + help="Number of different dtypes to pick per shape (default: 3)", + ) + parser.add_argument( + "--v1_list_files", + nargs="+", + required=True, + help="Path(s) to v1 list files (uuid\\trelative_path format) to exclude", + ) + parser.add_argument( + "--v1_sample_types", + nargs="+", + required=True, + help="Sample type for each v1 list file, in the same order (e.g., fusible_graph typical_graph sole_op_graph)", + ) + + args = parser.parse_args() + db = DB(args.db_path) + db.connect() + + print("Step 1: Loading v1 paths to exclude...") + v1_pairs = load_v1_paths(args.v1_list_files, args.v1_sample_types) + print(f" v1 (sample_type, path) pairs loaded: {len(v1_pairs)}") + + print("Step 2: Querying v2 candidates...") + candidates = query_v2_candidates(db, v1_pairs) + print(f" v2 candidate graphs: {len(candidates)}") + db.close() + + if not candidates: + print("No v2 candidates found. Done!") + return + + print(f"Step 3: Generating v2 groups (num_dtypes={args.num_dtypes})...") + session = get_session(args.db_path) + + try: + count = 0 + for sample_uid, group_uid in get_v2_group_members(candidates, args.num_dtypes): + new_group = GraphNetSampleGroup( + sample_uid=sample_uid, + group_uid=group_uid, + group_type="ai4c", + group_policy="bucket_policy_v2", + policy_version="1.0", + create_at=datetime.now(), + deleted=False, + ) + session.add(new_group) + count += 1 + + session.commit() + print(f" Inserted {count} group records.") + + except Exception: + session.rollback() + raise + finally: + session.close() + + print("Done!") + + +if __name__ == "__main__": + main() From a13a3564e620731c3e7074d910fac410b7cd27f8 Mon Sep 17 00:00:00 2001 From: Honglei-Qiu <1044497581@qq.com> Date: Tue, 24 Mar 2026 03:12:10 +0000 Subject: [PATCH 05/10] group by all --- sqlite/graph_net_sample_groups_insert.py | 278 ++++++++++++++++---- sqlite/graph_net_sample_groups_insert_v2.py | 238 ----------------- 2 files changed, 234 insertions(+), 282 deletions(-) delete mode 100755 sqlite/graph_net_sample_groups_insert_v2.py diff --git a/sqlite/graph_net_sample_groups_insert.py b/sqlite/graph_net_sample_groups_insert.py index 03c71498e..d21ed3611 100755 --- a/sqlite/graph_net_sample_groups_insert.py +++ b/sqlite/graph_net_sample_groups_insert.py @@ -2,8 +2,7 @@ import sqlite3 import uuid as uuid_module from datetime import datetime -from collections import namedtuple -from collections import defaultdict +from collections import namedtuple, defaultdict from orm_models import ( get_session, @@ -11,11 +10,6 @@ ) -GraphNetSampleUid = str -GraphNetSampleType = str -BucketId = str - - class DB: def __init__(self, path): self.path = path @@ -29,14 +23,12 @@ def query(self, sql, params=None): self.cur.execute(sql, params or ()) return self.cur.fetchall() - def exec(self, sql, params=None): - self.cur.execute(sql, params or ()) - self.conn.commit() - def close(self): self.conn.close() +# ── V1 types ── + SampleBucketInfo = namedtuple( "SampleBucketInfo", [ @@ -50,7 +42,30 @@ def close(self): ) -def get_ai4c_group_members(sample_bucket_infos: list[SampleBucketInfo]): +# ── V2 types ── + +CandidateGraph = namedtuple( + "CandidateGraph", + [ + "sample_uid", + "op_seq_bucket_id", + "input_shapes_bucket_id", + "input_dtypes_bucket_id", + "graph_hash", + ], +) + + +# ═══════════════════════════════════════════════════════════════════ +# V1: Rule 1 (bucket-internal stride-5) + Rule 2 (cross-shape aggregation) +# ═══════════════════════════════════════════════════════════════════ + + +def get_v1_group_members(sample_bucket_infos: list[SampleBucketInfo]): + """Rule 1: stride-5 sampling within each bucket, each sample is its own group. + Rule 2: cross-shape aggregation by (op_seq, dtype), heads share one group_uid.""" + + # Rule 1 for bucket_info in sample_bucket_infos: head_sample_uid = bucket_info.sample_uid sample_uids = bucket_info.sample_uids.split(",") @@ -63,6 +78,7 @@ def get_ai4c_group_members(sample_bucket_infos: list[SampleBucketInfo]): new_uuid = str(uuid_module.uuid4()) yield sample_uid, new_uuid + # Rule 2 grouped = defaultdict(list) for bucket_info in sample_bucket_infos: key = (bucket_info.op_seq_bucket_id, bucket_info.input_dtypes_bucket_id) @@ -75,21 +91,7 @@ def get_ai4c_group_members(sample_bucket_infos: list[SampleBucketInfo]): yield sample_uid, new_uuid -def main(): - parser = argparse.ArgumentParser( - description="Generate graph_net_sample_groups from graph_net_sample_buckets" - ) - parser.add_argument( - "--db_path", - type=str, - required=True, - help="Path to the SQLite database file", - ) - - args = parser.parse_args() - db = DB(args.db_path) - db.connect() - +def query_v1_candidates(db: DB): query_str = """ SELECT b.sample_uid, b.op_seq_bucket_id as op_seq, b.input_shapes_bucket_id, b.input_dtypes_bucket_id, b.sample_type, group_concat(b.sample_uid, ',') as sample_uids FROM ( @@ -104,36 +106,224 @@ def main(): ON s.uuid = b.sample_uid order by s.create_at asc, s.uuid asc ) b -GROUP BY b.sample_type, b.op_seq_bucket_id, b.input_shapes_bucket_id,b.input_dtypes_bucket_id; +GROUP BY b.sample_type, b.op_seq_bucket_id, b.input_shapes_bucket_id, b.input_dtypes_bucket_id; + """ + rows = db.query(query_str) + return [SampleBucketInfo(*row) for row in rows] + + +def insert_v1(db: DB, session, db_path: str): + print("=" * 60) + print("V1: Rule 1 (stride-5) + Rule 2 (cross-shape aggregation)") + print("=" * 60) + + candidates = query_v1_candidates(db) + print(f" Bucket groups: {len(candidates)}") + + count = 0 + for sample_uid, group_uid in get_v1_group_members(candidates): + new_group = GraphNetSampleGroup( + sample_uid=sample_uid, + group_uid=group_uid, + group_type="ai4c", + group_policy="bucket_policy_v1", + policy_version="1.0", + create_at=datetime.now(), + deleted=False, + ) + session.add(new_group) + count += 1 + + session.commit() + print(f" Inserted {count} v1 group records.") + return count + + +# ═══════════════════════════════════════════════════════════════════ +# V2: Rule 3 (global sparse sampling) + Rule 4 (dtype coverage) +# ═══════════════════════════════════════════════════════════════════ + + +def get_v2_group_members(candidates: list[CandidateGraph], num_dtypes: int): + """Rule 3: global sparse sampling, window_size=num_dtypes*5, pick first num_dtypes per window. + Rule 4: dtype coverage, per (op_seq, shape) pick up to num_dtypes different-dtype samples. """ - query_results = db.query(query_str) - print("Output:", len(query_results)) + # Index candidates by op_seq + by_op_seq = defaultdict(list) + for c in candidates: + by_op_seq[c.op_seq_bucket_id].append(c) + + rule3_selected_uids = set() + + # --- Rule 3: global sparse sampling --- + window_size = num_dtypes * 5 + for op_seq, op_candidates in by_op_seq.items(): + sorted_candidates = sorted(op_candidates, key=lambda c: c.sample_uid) + + rule3_uids = [] + for order_value, c in enumerate(sorted_candidates): + if (order_value % window_size) < num_dtypes: + rule3_uids.append(c.sample_uid) + rule3_selected_uids.add(c.sample_uid) - query_results = [SampleBucketInfo(*row) for row in query_results] + if rule3_uids: + group_uid = str(uuid_module.uuid4()) + for uid in rule3_uids: + yield uid, group_uid + + # --- Rule 4: dtype coverage --- + for op_seq, op_candidates in by_op_seq.items(): + remaining = [ + c for c in op_candidates if c.sample_uid not in rule3_selected_uids + ] + + # Sub-group by shape + by_shape = defaultdict(list) + for c in remaining: + by_shape[c.input_shapes_bucket_id].append(c) + + rule4_uids = [] + for shape, shape_candidates in by_shape.items(): + seen_dtypes = set() + for c in shape_candidates: + if ( + c.input_dtypes_bucket_id not in seen_dtypes + and len(seen_dtypes) < num_dtypes + ): + seen_dtypes.add(c.input_dtypes_bucket_id) + rule4_uids.append(c.sample_uid) + + if rule4_uids: + group_uid = str(uuid_module.uuid4()) + for uid in rule4_uids: + yield uid, group_uid + + +def query_v2_candidates(db: DB) -> list[CandidateGraph]: + query_str = """ +SELECT + sub.sample_uid, + sub.op_seq_bucket_id, + sub.input_shapes_bucket_id, + sub.input_dtypes_bucket_id, + sub.graph_hash +FROM ( + SELECT + s.uuid AS sample_uid, + b.op_seq_bucket_id, + b.input_shapes_bucket_id, + b.input_dtypes_bucket_id, + s.graph_hash, + ROW_NUMBER() OVER ( + PARTITION BY b.op_seq_bucket_id, b.input_shapes_bucket_id, b.input_dtypes_bucket_id, s.graph_hash + ORDER BY s.create_at ASC, s.uuid ASC + ) AS rn + FROM graph_sample s + JOIN graph_net_sample_buckets b ON s.uuid = b.sample_uid + WHERE s.deleted = 0 + AND s.sample_type != 'full_graph' +) sub +WHERE sub.rn = 1 + AND sub.sample_uid NOT IN ( + SELECT g.sample_uid + FROM graph_net_sample_groups g + WHERE g.group_policy = 'bucket_policy_v1' + AND g.deleted = 0 + ) +ORDER BY sub.op_seq_bucket_id, sub.input_shapes_bucket_id, sub.input_dtypes_bucket_id, sub.sample_uid; + """ + rows = db.query(query_str) + return [CandidateGraph(*row) for row in rows] + + +def insert_v2(db: DB, session, db_path: str, num_dtypes: int): + print("=" * 60) + print("V2: Rule 3 (sparse sampling) + Rule 4 (dtype coverage)") + print("=" * 60) + + candidates = query_v2_candidates(db) + print(f" V2 candidate graphs: {len(candidates)}") + + if not candidates: + print(" No v2 candidates found. Skipping.") + return 0 + + count = 0 + for sample_uid, group_uid in get_v2_group_members(candidates, num_dtypes): + new_group = GraphNetSampleGroup( + sample_uid=sample_uid, + group_uid=group_uid, + group_type="ai4c", + group_policy="bucket_policy_v2", + policy_version="1.0", + create_at=datetime.now(), + deleted=False, + ) + session.add(new_group) + count += 1 + + session.commit() + print(f" Inserted {count} v2 group records.") + return count + + +# ═══════════════════════════════════════════════════════════════════ +# Main +# ═══════════════════════════════════════════════════════════════════ + + +def main(): + parser = argparse.ArgumentParser( + description="Generate graph_net_sample_groups (v1 + v2)" + ) + parser.add_argument( + "--db_path", + type=str, + required=True, + help="Path to the SQLite database file", + ) + parser.add_argument( + "--num_dtypes", + type=int, + default=3, + help="Number of different dtypes to pick per shape in v2 (default: 3)", + ) + parser.add_argument( + "--only_v1", + action="store_true", + help="Only run v1 (Rule 1 + Rule 2)", + ) + parser.add_argument( + "--only_v2", + action="store_true", + help="Only run v2 (Rule 3 + Rule 4), requires v1 already in DB", + ) + + args = parser.parse_args() + db = DB(args.db_path) + db.connect() session = get_session(args.db_path) try: - for sample_uid, group_uid in get_ai4c_group_members(query_results): - new_group = GraphNetSampleGroup( - sample_uid=sample_uid, - group_uid=group_uid, - group_type="ai4c", - group_policy="bucket_policy_v1", - policy_version="1.0", - create_at=datetime.now(), - deleted=False, - ) - - session.add(new_group) - session.commit() + run_v1 = not args.only_v2 + run_v2 = not args.only_v1 + + if run_v1: + insert_v1(db, session, args.db_path) + + if run_v2: + insert_v2(db, session, args.db_path, args.num_dtypes) except Exception: session.rollback() raise finally: session.close() + db.close() + + print("\nDone!") if __name__ == "__main__": diff --git a/sqlite/graph_net_sample_groups_insert_v2.py b/sqlite/graph_net_sample_groups_insert_v2.py deleted file mode 100755 index 9091ff015..000000000 --- a/sqlite/graph_net_sample_groups_insert_v2.py +++ /dev/null @@ -1,238 +0,0 @@ -import argparse -import sqlite3 -import uuid as uuid_module -from datetime import datetime -from collections import namedtuple, defaultdict - -from orm_models import ( - get_session, - GraphNetSampleGroup, -) - - -class DB: - def __init__(self, path): - self.path = path - - def connect(self): - self.conn = sqlite3.connect(self.path) - self.conn.row_factory = sqlite3.Row - self.cur = self.conn.cursor() - - def query(self, sql, params=None): - self.cur.execute(sql, params or ()) - return self.cur.fetchall() - - def close(self): - self.conn.close() - - -CandidateGraph = namedtuple( - "CandidateGraph", - [ - "sample_uid", - "op_seq_bucket_id", - "input_shapes_bucket_id", - "input_dtypes_bucket_id", - "graph_hash", - ], -) - - -def load_v1_paths( - v1_list_files: list[str], sample_types: list[str] -) -> set[tuple[str, str]]: - """ - Load (sample_type, relative_model_path) pairs from v1 list files. - File format: uuid\\trelative_path (tab separated, one per line) - - Each file corresponds to a sample_type, passed via sample_types in the same order. - """ - pairs = set() - for filepath, sample_type in zip(v1_list_files, sample_types): - with open(filepath, "r") as f: - for line in f: - line = line.strip() - if not line: - continue - parts = line.split("\t", 1) - if len(parts) == 2: - pairs.add((sample_type, parts[1])) - return pairs - - -def query_v2_candidates(db: DB, v1_pairs: set[tuple[str, str]]) -> list[CandidateGraph]: - query_str = """ -SELECT - sub.sample_uid, - sub.op_seq_bucket_id, - sub.input_shapes_bucket_id, - sub.input_dtypes_bucket_id, - sub.graph_hash, - sub.relative_model_path, - sub.sample_type -FROM ( - SELECT - s.uuid AS sample_uid, - s.relative_model_path, - s.sample_type, - b.op_seq_bucket_id, - b.input_shapes_bucket_id, - b.input_dtypes_bucket_id, - s.graph_hash, - ROW_NUMBER() OVER ( - PARTITION BY b.op_seq_bucket_id, b.input_shapes_bucket_id, b.input_dtypes_bucket_id, s.graph_hash - ORDER BY s.create_at ASC, s.uuid ASC - ) AS rn - FROM graph_sample s - JOIN graph_net_sample_buckets b ON s.uuid = b.sample_uid - WHERE s.deleted = 0 - AND s.sample_type != 'full_graph' -) sub -WHERE sub.rn = 1 -ORDER BY sub.op_seq_bucket_id, sub.input_shapes_bucket_id, sub.input_dtypes_bucket_id, sub.sample_uid; - """ - rows = db.query(query_str) - # Filter out v1 selected (sample_type, path) pairs in Python - candidates = [] - for row in rows: - sample_uid, op_seq, shapes, dtypes, graph_hash, rel_path, sample_type = row - if (sample_type, rel_path) not in v1_pairs: - candidates.append( - CandidateGraph(sample_uid, op_seq, shapes, dtypes, graph_hash) - ) - return candidates - - -def get_v2_group_members(candidates: list[CandidateGraph], num_dtypes: int): - # Index candidates by op_seq - by_op_seq = defaultdict(list) - for c in candidates: - by_op_seq[c.op_seq_bucket_id].append(c) - - rule3_selected_uids = set() - - # --- Rule 3: global sparse sampling --- - # Window size = num_dtypes * stride(5) = 15, pick first num_dtypes(3) per window - window_size = num_dtypes * 5 - for op_seq, op_candidates in by_op_seq.items(): - sorted_candidates = sorted(op_candidates, key=lambda c: c.sample_uid) - - rule3_uids = [] - for order_value, c in enumerate(sorted_candidates): - if (order_value % window_size) < num_dtypes: - rule3_uids.append(c.sample_uid) - rule3_selected_uids.add(c.sample_uid) - - if rule3_uids: - group_uid = str(uuid_module.uuid4()) - for uid in rule3_uids: - yield uid, group_uid - - # --- Rule 4: dtype coverage --- - for op_seq, op_candidates in by_op_seq.items(): - remaining = [ - c for c in op_candidates if c.sample_uid not in rule3_selected_uids - ] - - # Sub-group by shape - by_shape = defaultdict(list) - for c in remaining: - by_shape[c.input_shapes_bucket_id].append(c) - - rule4_uids = [] - for shape, shape_candidates in by_shape.items(): - # Pick up to num_dtypes samples with different dtypes - seen_dtypes = set() - for c in shape_candidates: - if ( - c.input_dtypes_bucket_id not in seen_dtypes - and len(seen_dtypes) < num_dtypes - ): - seen_dtypes.add(c.input_dtypes_bucket_id) - rule4_uids.append(c.sample_uid) - - if rule4_uids: - group_uid = str(uuid_module.uuid4()) - for uid in rule4_uids: - yield uid, group_uid - - -def main(): - parser = argparse.ArgumentParser( - description="Generate graph_net_sample_groups with bucket_policy_v2" - ) - parser.add_argument( - "--db_path", - type=str, - required=True, - help="Path to the SQLite database file", - ) - parser.add_argument( - "--num_dtypes", - type=int, - default=3, - help="Number of different dtypes to pick per shape (default: 3)", - ) - parser.add_argument( - "--v1_list_files", - nargs="+", - required=True, - help="Path(s) to v1 list files (uuid\\trelative_path format) to exclude", - ) - parser.add_argument( - "--v1_sample_types", - nargs="+", - required=True, - help="Sample type for each v1 list file, in the same order (e.g., fusible_graph typical_graph sole_op_graph)", - ) - - args = parser.parse_args() - db = DB(args.db_path) - db.connect() - - print("Step 1: Loading v1 paths to exclude...") - v1_pairs = load_v1_paths(args.v1_list_files, args.v1_sample_types) - print(f" v1 (sample_type, path) pairs loaded: {len(v1_pairs)}") - - print("Step 2: Querying v2 candidates...") - candidates = query_v2_candidates(db, v1_pairs) - print(f" v2 candidate graphs: {len(candidates)}") - db.close() - - if not candidates: - print("No v2 candidates found. Done!") - return - - print(f"Step 3: Generating v2 groups (num_dtypes={args.num_dtypes})...") - session = get_session(args.db_path) - - try: - count = 0 - for sample_uid, group_uid in get_v2_group_members(candidates, args.num_dtypes): - new_group = GraphNetSampleGroup( - sample_uid=sample_uid, - group_uid=group_uid, - group_type="ai4c", - group_policy="bucket_policy_v2", - policy_version="1.0", - create_at=datetime.now(), - deleted=False, - ) - session.add(new_group) - count += 1 - - session.commit() - print(f" Inserted {count} group records.") - - except Exception: - session.rollback() - raise - finally: - session.close() - - print("Done!") - - -if __name__ == "__main__": - main() From b4a098a8ed51daf55b21bb113b7b2826a7831005 Mon Sep 17 00:00:00 2001 From: Honglei-Qiu <1044497581@qq.com> Date: Tue, 24 Mar 2026 03:39:54 +0000 Subject: [PATCH 06/10] group by all --- sqlite/graph_net_sample_groups_insert.py | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/sqlite/graph_net_sample_groups_insert.py b/sqlite/graph_net_sample_groups_insert.py index d21ed3611..ae010ab05 100755 --- a/sqlite/graph_net_sample_groups_insert.py +++ b/sqlite/graph_net_sample_groups_insert.py @@ -35,7 +35,6 @@ def close(self): "sample_uid", "op_seq_bucket_id", "input_shapes_bucket_id", - "input_dtypes_bucket_id", "sample_type", "sample_uids", ], @@ -63,7 +62,7 @@ def close(self): def get_v1_group_members(sample_bucket_infos: list[SampleBucketInfo]): """Rule 1: stride-5 sampling within each bucket, each sample is its own group. - Rule 2: cross-shape aggregation by (op_seq, dtype), heads share one group_uid.""" + Rule 2: cross-shape aggregation by op_seq, heads share one group_uid.""" # Rule 1 for bucket_info in sample_bucket_infos: @@ -81,11 +80,10 @@ def get_v1_group_members(sample_bucket_infos: list[SampleBucketInfo]): # Rule 2 grouped = defaultdict(list) for bucket_info in sample_bucket_infos: - key = (bucket_info.op_seq_bucket_id, bucket_info.input_dtypes_bucket_id) - grouped[key].append(bucket_info.sample_uid) + grouped[bucket_info.op_seq_bucket_id].append(bucket_info.sample_uid) grouped = dict(grouped) - for key, sample_uids in grouped.items(): + for op_seq, sample_uids in grouped.items(): new_uuid = str(uuid_module.uuid4()) for sample_uid in sample_uids: yield sample_uid, new_uuid @@ -93,20 +91,19 @@ def get_v1_group_members(sample_bucket_infos: list[SampleBucketInfo]): def query_v1_candidates(db: DB): query_str = """ -SELECT b.sample_uid, b.op_seq_bucket_id as op_seq, b.input_shapes_bucket_id, b.input_dtypes_bucket_id, b.sample_type, group_concat(b.sample_uid, ',') as sample_uids +SELECT b.sample_uid, b.op_seq_bucket_id as op_seq, b.input_shapes_bucket_id, b.sample_type, group_concat(b.sample_uid, ',') as sample_uids FROM ( SELECT s.uuid AS sample_uid, s.sample_type AS sample_type, b.op_seq_bucket_id AS op_seq_bucket_id, - b.input_shapes_bucket_id AS input_shapes_bucket_id, - b.input_dtypes_bucket_id AS input_dtypes_bucket_id + b.input_shapes_bucket_id AS input_shapes_bucket_id FROM graph_sample s JOIN graph_net_sample_buckets b ON s.uuid = b.sample_uid order by s.create_at asc, s.uuid asc ) b -GROUP BY b.sample_type, b.op_seq_bucket_id, b.input_shapes_bucket_id, b.input_dtypes_bucket_id; +GROUP BY b.sample_type, b.op_seq_bucket_id, b.input_shapes_bucket_id; """ rows = db.query(query_str) return [SampleBucketInfo(*row) for row in rows] From 238d96653eda544f15d4d7fe3e6769683e3b1f20 Mon Sep 17 00:00:00 2001 From: Honglei-Qiu <1044497581@qq.com> Date: Tue, 24 Mar 2026 07:05:25 +0000 Subject: [PATCH 07/10] group by all --- sqlite/graph_net_sample_groups_insert.py | 354 ++++++++++------------- 1 file changed, 158 insertions(+), 196 deletions(-) diff --git a/sqlite/graph_net_sample_groups_insert.py b/sqlite/graph_net_sample_groups_insert.py index ae010ab05..587fb56ec 100755 --- a/sqlite/graph_net_sample_groups_insert.py +++ b/sqlite/graph_net_sample_groups_insert.py @@ -29,240 +29,220 @@ def close(self): # ── V1 types ── -SampleBucketInfo = namedtuple( - "SampleBucketInfo", - [ - "sample_uid", - "op_seq_bucket_id", - "input_shapes_bucket_id", - "sample_type", - "sample_uids", - ], +BucketGroup = namedtuple( + "BucketGroup", + ["head_uid", "op_seq", "shapes", "sample_type", "all_uids_csv"], ) - # ── V2 types ── -CandidateGraph = namedtuple( - "CandidateGraph", - [ - "sample_uid", - "op_seq_bucket_id", - "input_shapes_bucket_id", - "input_dtypes_bucket_id", - "graph_hash", - ], +V2Candidate = namedtuple( + "V2Candidate", + ["uid", "op_seq", "shapes", "dtypes"], ) +def _new_group_id(): + return str(uuid_module.uuid4()) + + +def _print_stats(stats, rule_order): + total_records = 0 + total_groups = 0 + for rule_name in rule_order: + if rule_name in stats: + r = stats[rule_name]["records"] + g = len(stats[rule_name]["groups"]) + print(f" {rule_name}: {r} records, {g} groups") + total_records += r + total_groups += g + print(f" Total: {total_records} records, {total_groups} groups.") + + # ═══════════════════════════════════════════════════════════════════ # V1: Rule 1 (bucket-internal stride-5) + Rule 2 (cross-shape aggregation) # ═══════════════════════════════════════════════════════════════════ -def get_v1_group_members(sample_bucket_infos: list[SampleBucketInfo]): - """Rule 1: stride-5 sampling within each bucket, each sample is its own group. - Rule 2: cross-shape aggregation by op_seq, heads share one group_uid.""" - - # Rule 1 - for bucket_info in sample_bucket_infos: - head_sample_uid = bucket_info.sample_uid - sample_uids = bucket_info.sample_uids.split(",") - selected_other_sample_uids = [ - other_sample_uid - for other_sample_uid in sample_uids[::5] - if other_sample_uid != head_sample_uid - ] - for sample_uid in selected_other_sample_uids: - new_uuid = str(uuid_module.uuid4()) - yield sample_uid, new_uuid - - # Rule 2 - grouped = defaultdict(list) - for bucket_info in sample_bucket_infos: - grouped[bucket_info.op_seq_bucket_id].append(bucket_info.sample_uid) - - grouped = dict(grouped) - for op_seq, sample_uids in grouped.items(): - new_uuid = str(uuid_module.uuid4()) - for sample_uid in sample_uids: - yield sample_uid, new_uuid - - -def query_v1_candidates(db: DB): - query_str = """ -SELECT b.sample_uid, b.op_seq_bucket_id as op_seq, b.input_shapes_bucket_id, b.sample_type, group_concat(b.sample_uid, ',') as sample_uids +def generate_v1_groups(bucket_groups: list[BucketGroup]): + """Rule 1: stride-5 sampling, each sampled uid gets its own group. + Rule 2: group all bucket heads that share the same op_seq. + Yields (uid, group_id, rule_name).""" + + # Rule 1: stride-5 sampling + for bucket in bucket_groups: + members = bucket.all_uids_csv.split(",") + sampled = [uid for uid in members[::5] if uid != bucket.head_uid] + for uid in sampled: + yield uid, _new_group_id(), "rule1" + + # Rule 2: group heads by op_seq + op_seq_to_heads = defaultdict(list) + for bucket in bucket_groups: + op_seq_to_heads[bucket.op_seq].append(bucket.head_uid) + + for heads in op_seq_to_heads.values(): + group_id = _new_group_id() + for uid in heads: + yield uid, group_id, "rule2" + + +def query_v1_bucket_groups(db: DB) -> list[BucketGroup]: + sql = """ +SELECT + sub.sample_uid, + sub.op_seq_bucket_id, + sub.input_shapes_bucket_id, + sub.sample_type, + group_concat(sub.sample_uid, ',') AS all_uids FROM ( SELECT - s.uuid AS sample_uid, - s.sample_type AS sample_type, - b.op_seq_bucket_id AS op_seq_bucket_id, - b.input_shapes_bucket_id AS input_shapes_bucket_id + s.uuid AS sample_uid, + s.sample_type, + b.op_seq_bucket_id, + b.input_shapes_bucket_id FROM graph_sample s - JOIN graph_net_sample_buckets b - ON s.uuid = b.sample_uid - order by s.create_at asc, s.uuid asc -) b -GROUP BY b.sample_type, b.op_seq_bucket_id, b.input_shapes_bucket_id; + JOIN graph_net_sample_buckets b ON s.uuid = b.sample_uid + ORDER BY s.create_at ASC, s.uuid ASC +) sub +GROUP BY sub.sample_type, sub.op_seq_bucket_id, sub.input_shapes_bucket_id; """ - rows = db.query(query_str) - return [SampleBucketInfo(*row) for row in rows] + return [BucketGroup(*row) for row in db.query(sql)] -def insert_v1(db: DB, session, db_path: str): +def insert_v1_groups(db: DB, session): print("=" * 60) print("V1: Rule 1 (stride-5) + Rule 2 (cross-shape aggregation)") print("=" * 60) - candidates = query_v1_candidates(db) - print(f" Bucket groups: {len(candidates)}") - - count = 0 - for sample_uid, group_uid in get_v1_group_members(candidates): - new_group = GraphNetSampleGroup( - sample_uid=sample_uid, - group_uid=group_uid, - group_type="ai4c", - group_policy="bucket_policy_v1", - policy_version="1.0", - create_at=datetime.now(), - deleted=False, + bucket_groups = query_v1_bucket_groups(db) + print(f" Bucket groups: {len(bucket_groups)}") + + stats = defaultdict(lambda: {"records": 0, "groups": set()}) + for uid, group_id, rule_name in generate_v1_groups(bucket_groups): + session.add( + GraphNetSampleGroup( + sample_uid=uid, + group_uid=group_id, + group_type="ai4c", + group_policy="bucket_policy_v1", + policy_version="1.0", + create_at=datetime.now(), + deleted=False, + ) ) - session.add(new_group) - count += 1 + stats[rule_name]["records"] += 1 + stats[rule_name]["groups"].add(group_id) session.commit() - print(f" Inserted {count} v1 group records.") - return count + _print_stats(stats, ["rule1", "rule2"]) # ═══════════════════════════════════════════════════════════════════ -# V2: Rule 3 (global sparse sampling) + Rule 4 (dtype coverage) +# V2: Rule 4 (dtype coverage) + Rule 3 (global sparse sampling) # ═══════════════════════════════════════════════════════════════════ -def get_v2_group_members(candidates: list[CandidateGraph], num_dtypes: int): - """Rule 3: global sparse sampling, window_size=num_dtypes*5, pick first num_dtypes per window. - Rule 4: dtype coverage, per (op_seq, shape) pick up to num_dtypes different-dtype samples. - """ +def generate_v2_groups(candidates: list[V2Candidate], num_dtypes: int): + """Rule 4 runs first to ensure full dtype coverage. + Rule 3 then sparse-samples from the remaining candidates. + Yields (uid, group_id, rule_name).""" - # Index candidates by op_seq - by_op_seq = defaultdict(list) + candidates_by_op_seq = defaultdict(list) for c in candidates: - by_op_seq[c.op_seq_bucket_id].append(c) + candidates_by_op_seq[c.op_seq].append(c) - rule3_selected_uids = set() + dtype_covered_uids = set() - # --- Rule 3: global sparse sampling --- - window_size = num_dtypes * 5 - for op_seq, op_candidates in by_op_seq.items(): - sorted_candidates = sorted(op_candidates, key=lambda c: c.sample_uid) - - rule3_uids = [] - for order_value, c in enumerate(sorted_candidates): - if (order_value % window_size) < num_dtypes: - rule3_uids.append(c.sample_uid) - rule3_selected_uids.add(c.sample_uid) - - if rule3_uids: - group_uid = str(uuid_module.uuid4()) - for uid in rule3_uids: - yield uid, group_uid - - # --- Rule 4: dtype coverage --- - for op_seq, op_candidates in by_op_seq.items(): - remaining = [ - c for c in op_candidates if c.sample_uid not in rule3_selected_uids - ] - - # Sub-group by shape - by_shape = defaultdict(list) - for c in remaining: - by_shape[c.input_shapes_bucket_id].append(c) - - rule4_uids = [] - for shape, shape_candidates in by_shape.items(): + # --- Rule 4: dtype coverage (runs first) --- + for op_seq, op_candidates in candidates_by_op_seq.items(): + candidates_by_shape = defaultdict(list) + for c in op_candidates: + candidates_by_shape[c.shapes].append(c) + + selected_uids = [] + for shape, shape_candidates in candidates_by_shape.items(): seen_dtypes = set() for c in shape_candidates: - if ( - c.input_dtypes_bucket_id not in seen_dtypes - and len(seen_dtypes) < num_dtypes - ): - seen_dtypes.add(c.input_dtypes_bucket_id) - rule4_uids.append(c.sample_uid) + if c.dtypes not in seen_dtypes and len(seen_dtypes) < num_dtypes: + seen_dtypes.add(c.dtypes) + selected_uids.append(c.uid) + dtype_covered_uids.add(c.uid) - if rule4_uids: - group_uid = str(uuid_module.uuid4()) - for uid in rule4_uids: - yield uid, group_uid + if selected_uids: + group_id = _new_group_id() + for uid in selected_uids: + yield uid, group_id, "rule4" + + # --- Rule 3: global sparse sampling (on remaining candidates) --- + window_size = num_dtypes * 5 + for op_seq, op_candidates in candidates_by_op_seq.items(): + remaining = [c for c in op_candidates if c.uid not in dtype_covered_uids] + remaining.sort(key=lambda c: c.uid) + selected_uids = [] + for idx, c in enumerate(remaining): + if (idx % window_size) < num_dtypes: + selected_uids.append(c.uid) -def query_v2_candidates(db: DB) -> list[CandidateGraph]: - query_str = """ + if selected_uids: + group_id = _new_group_id() + for uid in selected_uids: + yield uid, group_id, "rule3" + + +def query_v2_candidates(db: DB) -> list[V2Candidate]: + sql = """ SELECT - sub.sample_uid, - sub.op_seq_bucket_id, - sub.input_shapes_bucket_id, - sub.input_dtypes_bucket_id, - sub.graph_hash -FROM ( - SELECT - s.uuid AS sample_uid, - b.op_seq_bucket_id, - b.input_shapes_bucket_id, - b.input_dtypes_bucket_id, - s.graph_hash, - ROW_NUMBER() OVER ( - PARTITION BY b.op_seq_bucket_id, b.input_shapes_bucket_id, b.input_dtypes_bucket_id, s.graph_hash - ORDER BY s.create_at ASC, s.uuid ASC - ) AS rn - FROM graph_sample s - JOIN graph_net_sample_buckets b ON s.uuid = b.sample_uid - WHERE s.deleted = 0 - AND s.sample_type != 'full_graph' -) sub -WHERE sub.rn = 1 - AND sub.sample_uid NOT IN ( + s.uuid, + b.op_seq_bucket_id, + b.input_shapes_bucket_id, + b.input_dtypes_bucket_id +FROM graph_sample s +JOIN graph_net_sample_buckets b ON s.uuid = b.sample_uid +WHERE s.deleted = 0 + AND s.sample_type != 'full_graph' + AND s.uuid NOT IN ( SELECT g.sample_uid FROM graph_net_sample_groups g WHERE g.group_policy = 'bucket_policy_v1' AND g.deleted = 0 ) -ORDER BY sub.op_seq_bucket_id, sub.input_shapes_bucket_id, sub.input_dtypes_bucket_id, sub.sample_uid; +ORDER BY b.op_seq_bucket_id, b.input_shapes_bucket_id, b.input_dtypes_bucket_id, s.uuid; """ - rows = db.query(query_str) - return [CandidateGraph(*row) for row in rows] + return [V2Candidate(*row) for row in db.query(sql)] -def insert_v2(db: DB, session, db_path: str, num_dtypes: int): +def insert_v2_groups(db: DB, session, num_dtypes: int): print("=" * 60) - print("V2: Rule 3 (sparse sampling) + Rule 4 (dtype coverage)") + print("V2: Rule 4 (dtype coverage) + Rule 3 (sparse sampling)") print("=" * 60) candidates = query_v2_candidates(db) - print(f" V2 candidate graphs: {len(candidates)}") + print(f" V2 candidates: {len(candidates)}") if not candidates: print(" No v2 candidates found. Skipping.") - return 0 - - count = 0 - for sample_uid, group_uid in get_v2_group_members(candidates, num_dtypes): - new_group = GraphNetSampleGroup( - sample_uid=sample_uid, - group_uid=group_uid, - group_type="ai4c", - group_policy="bucket_policy_v2", - policy_version="1.0", - create_at=datetime.now(), - deleted=False, + return + + stats = defaultdict(lambda: {"records": 0, "groups": set()}) + for uid, group_id, rule_name in generate_v2_groups(candidates, num_dtypes): + session.add( + GraphNetSampleGroup( + sample_uid=uid, + group_uid=group_id, + group_type="ai4c", + group_policy="bucket_policy_v2", + policy_version="1.0", + create_at=datetime.now(), + deleted=False, + ) ) - session.add(new_group) - count += 1 + stats[rule_name]["records"] += 1 + stats[rule_name]["groups"].add(group_id) session.commit() - print(f" Inserted {count} v2 group records.") - return count + _print_stats(stats, ["rule4", "rule3"]) # ═══════════════════════════════════════════════════════════════════ @@ -286,33 +266,15 @@ def main(): default=3, help="Number of different dtypes to pick per shape in v2 (default: 3)", ) - parser.add_argument( - "--only_v1", - action="store_true", - help="Only run v1 (Rule 1 + Rule 2)", - ) - parser.add_argument( - "--only_v2", - action="store_true", - help="Only run v2 (Rule 3 + Rule 4), requires v1 already in DB", - ) - args = parser.parse_args() + db = DB(args.db_path) db.connect() - session = get_session(args.db_path) try: - run_v1 = not args.only_v2 - run_v2 = not args.only_v1 - - if run_v1: - insert_v1(db, session, args.db_path) - - if run_v2: - insert_v2(db, session, args.db_path, args.num_dtypes) - + insert_v1_groups(db, session) + insert_v2_groups(db, session, args.num_dtypes) except Exception: session.rollback() raise From 489af4392dca766412081c534a4d6cf81c59fbad Mon Sep 17 00:00:00 2001 From: Honglei-Qiu <1044497581@qq.com> Date: Tue, 24 Mar 2026 07:32:38 +0000 Subject: [PATCH 08/10] group by all --- sqlite/graph_net_sample_groups_insert.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/sqlite/graph_net_sample_groups_insert.py b/sqlite/graph_net_sample_groups_insert.py index 587fb56ec..c89d96260 100755 --- a/sqlite/graph_net_sample_groups_insert.py +++ b/sqlite/graph_net_sample_groups_insert.py @@ -38,7 +38,7 @@ def close(self): V2Candidate = namedtuple( "V2Candidate", - ["uid", "op_seq", "shapes", "dtypes"], + ["uid", "sample_type", "op_seq", "shapes", "dtypes"], ) @@ -76,12 +76,14 @@ def generate_v1_groups(bucket_groups: list[BucketGroup]): for uid in sampled: yield uid, _new_group_id(), "rule1" - # Rule 2: group heads by op_seq - op_seq_to_heads = defaultdict(list) + # Rule 2: group heads by (sample_type, op_seq) + type_op_seq_to_heads = defaultdict(list) for bucket in bucket_groups: - op_seq_to_heads[bucket.op_seq].append(bucket.head_uid) + type_op_seq_to_heads[(bucket.sample_type, bucket.op_seq)].append( + bucket.head_uid + ) - for heads in op_seq_to_heads.values(): + for heads in type_op_seq_to_heads.values(): group_id = _new_group_id() for uid in heads: yield uid, group_id, "rule2" @@ -150,12 +152,12 @@ def generate_v2_groups(candidates: list[V2Candidate], num_dtypes: int): candidates_by_op_seq = defaultdict(list) for c in candidates: - candidates_by_op_seq[c.op_seq].append(c) + candidates_by_op_seq[(c.sample_type, c.op_seq)].append(c) dtype_covered_uids = set() # --- Rule 4: dtype coverage (runs first) --- - for op_seq, op_candidates in candidates_by_op_seq.items(): + for key, op_candidates in candidates_by_op_seq.items(): candidates_by_shape = defaultdict(list) for c in op_candidates: candidates_by_shape[c.shapes].append(c) @@ -176,7 +178,7 @@ def generate_v2_groups(candidates: list[V2Candidate], num_dtypes: int): # --- Rule 3: global sparse sampling (on remaining candidates) --- window_size = num_dtypes * 5 - for op_seq, op_candidates in candidates_by_op_seq.items(): + for key, op_candidates in candidates_by_op_seq.items(): remaining = [c for c in op_candidates if c.uid not in dtype_covered_uids] remaining.sort(key=lambda c: c.uid) @@ -195,6 +197,7 @@ def query_v2_candidates(db: DB) -> list[V2Candidate]: sql = """ SELECT s.uuid, + s.sample_type, b.op_seq_bucket_id, b.input_shapes_bucket_id, b.input_dtypes_bucket_id @@ -208,7 +211,7 @@ def query_v2_candidates(db: DB) -> list[V2Candidate]: WHERE g.group_policy = 'bucket_policy_v1' AND g.deleted = 0 ) -ORDER BY b.op_seq_bucket_id, b.input_shapes_bucket_id, b.input_dtypes_bucket_id, s.uuid; +ORDER BY s.sample_type, b.op_seq_bucket_id, b.input_shapes_bucket_id, b.input_dtypes_bucket_id, s.uuid; """ return [V2Candidate(*row) for row in db.query(sql)] From 05e5d0ffca7a9168647443479956b1a1c547b276 Mon Sep 17 00:00:00 2001 From: Honglei-Qiu <1044497581@qq.com> Date: Tue, 24 Mar 2026 09:53:04 +0000 Subject: [PATCH 09/10] group by all --- sqlite/graph_net_sample_groups_insert.py | 92 +++++++++++++----------- 1 file changed, 50 insertions(+), 42 deletions(-) diff --git a/sqlite/graph_net_sample_groups_insert.py b/sqlite/graph_net_sample_groups_insert.py index c89d96260..3619ea6db 100755 --- a/sqlite/graph_net_sample_groups_insert.py +++ b/sqlite/graph_net_sample_groups_insert.py @@ -46,17 +46,22 @@ def _new_group_id(): return str(uuid_module.uuid4()) -def _print_stats(stats, rule_order): +def _print_stats(stats): + rule_order = ["rule1", "rule2", "rule4", "rule3"] + sample_types = sorted({st for st, _ in stats.keys()}) total_records = 0 total_groups = 0 - for rule_name in rule_order: - if rule_name in stats: - r = stats[rule_name]["records"] - g = len(stats[rule_name]["groups"]) - print(f" {rule_name}: {r} records, {g} groups") - total_records += r - total_groups += g - print(f" Total: {total_records} records, {total_groups} groups.") + for sample_type in sample_types: + print(f"\n [{sample_type}]") + for rule_name in rule_order: + key = (sample_type, rule_name) + if key in stats: + r = stats[key]["records"] + g = len(stats[key]["groups"]) + print(f" {rule_name}: {r} records, {g} groups") + total_records += r + total_groups += g + print(f"\n Total: {total_records} records, {total_groups} groups.") # ═══════════════════════════════════════════════════════════════════ @@ -67,14 +72,14 @@ def _print_stats(stats, rule_order): def generate_v1_groups(bucket_groups: list[BucketGroup]): """Rule 1: stride-5 sampling, each sampled uid gets its own group. Rule 2: group all bucket heads that share the same op_seq. - Yields (uid, group_id, rule_name).""" + Yields (sample_type, uid, group_id, rule_name).""" # Rule 1: stride-5 sampling for bucket in bucket_groups: members = bucket.all_uids_csv.split(",") - sampled = [uid for uid in members[::5] if uid != bucket.head_uid] + sampled = [uid for uid in members[::16] if uid != bucket.head_uid] for uid in sampled: - yield uid, _new_group_id(), "rule1" + yield bucket.sample_type, uid, _new_group_id(), "rule1" # Rule 2: group heads by (sample_type, op_seq) type_op_seq_to_heads = defaultdict(list) @@ -83,10 +88,10 @@ def generate_v1_groups(bucket_groups: list[BucketGroup]): bucket.head_uid ) - for heads in type_op_seq_to_heads.values(): + for (sample_type, _), heads in type_op_seq_to_heads.items(): group_id = _new_group_id() for uid in heads: - yield uid, group_id, "rule2" + yield sample_type, uid, group_id, "rule2" def query_v1_bucket_groups(db: DB) -> list[BucketGroup]: @@ -113,15 +118,11 @@ def query_v1_bucket_groups(db: DB) -> list[BucketGroup]: def insert_v1_groups(db: DB, session): - print("=" * 60) - print("V1: Rule 1 (stride-5) + Rule 2 (cross-shape aggregation)") - print("=" * 60) - bucket_groups = query_v1_bucket_groups(db) - print(f" Bucket groups: {len(bucket_groups)}") + print(f"Bucket groups: {len(bucket_groups)}") stats = defaultdict(lambda: {"records": 0, "groups": set()}) - for uid, group_id, rule_name in generate_v1_groups(bucket_groups): + for sample_type, uid, group_id, rule_name in generate_v1_groups(bucket_groups): session.add( GraphNetSampleGroup( sample_uid=uid, @@ -133,11 +134,11 @@ def insert_v1_groups(db: DB, session): deleted=False, ) ) - stats[rule_name]["records"] += 1 - stats[rule_name]["groups"].add(group_id) + stats[(sample_type, rule_name)]["records"] += 1 + stats[(sample_type, rule_name)]["groups"].add(group_id) session.commit() - _print_stats(stats, ["rule1", "rule2"]) + return stats # ═══════════════════════════════════════════════════════════════════ @@ -148,7 +149,7 @@ def insert_v1_groups(db: DB, session): def generate_v2_groups(candidates: list[V2Candidate], num_dtypes: int): """Rule 4 runs first to ensure full dtype coverage. Rule 3 then sparse-samples from the remaining candidates. - Yields (uid, group_id, rule_name).""" + Yields (sample_type, uid, group_id, rule_name).""" candidates_by_op_seq = defaultdict(list) for c in candidates: @@ -157,7 +158,7 @@ def generate_v2_groups(candidates: list[V2Candidate], num_dtypes: int): dtype_covered_uids = set() # --- Rule 4: dtype coverage (runs first) --- - for key, op_candidates in candidates_by_op_seq.items(): + for (sample_type, _), op_candidates in candidates_by_op_seq.items(): candidates_by_shape = defaultdict(list) for c in op_candidates: candidates_by_shape[c.shapes].append(c) @@ -174,11 +175,11 @@ def generate_v2_groups(candidates: list[V2Candidate], num_dtypes: int): if selected_uids: group_id = _new_group_id() for uid in selected_uids: - yield uid, group_id, "rule4" + yield sample_type, uid, group_id, "rule4" # --- Rule 3: global sparse sampling (on remaining candidates) --- window_size = num_dtypes * 5 - for key, op_candidates in candidates_by_op_seq.items(): + for (sample_type, _), op_candidates in candidates_by_op_seq.items(): remaining = [c for c in op_candidates if c.uid not in dtype_covered_uids] remaining.sort(key=lambda c: c.uid) @@ -190,7 +191,7 @@ def generate_v2_groups(candidates: list[V2Candidate], num_dtypes: int): if selected_uids: group_id = _new_group_id() for uid in selected_uids: - yield uid, group_id, "rule3" + yield sample_type, uid, group_id, "rule3" def query_v2_candidates(db: DB) -> list[V2Candidate]: @@ -217,19 +218,17 @@ def query_v2_candidates(db: DB) -> list[V2Candidate]: def insert_v2_groups(db: DB, session, num_dtypes: int): - print("=" * 60) - print("V2: Rule 4 (dtype coverage) + Rule 3 (sparse sampling)") - print("=" * 60) - candidates = query_v2_candidates(db) - print(f" V2 candidates: {len(candidates)}") + print(f"V2 candidates: {len(candidates)}") + stats = defaultdict(lambda: {"records": 0, "groups": set()}) if not candidates: - print(" No v2 candidates found. Skipping.") - return + print("No v2 candidates found. Skipping.") + return stats - stats = defaultdict(lambda: {"records": 0, "groups": set()}) - for uid, group_id, rule_name in generate_v2_groups(candidates, num_dtypes): + for sample_type, uid, group_id, rule_name in generate_v2_groups( + candidates, num_dtypes + ): session.add( GraphNetSampleGroup( sample_uid=uid, @@ -241,11 +240,11 @@ def insert_v2_groups(db: DB, session, num_dtypes: int): deleted=False, ) ) - stats[rule_name]["records"] += 1 - stats[rule_name]["groups"].add(group_id) + stats[(sample_type, rule_name)]["records"] += 1 + stats[(sample_type, rule_name)]["groups"].add(group_id) session.commit() - _print_stats(stats, ["rule4", "rule3"]) + return stats # ═══════════════════════════════════════════════════════════════════ @@ -276,8 +275,8 @@ def main(): session = get_session(args.db_path) try: - insert_v1_groups(db, session) - insert_v2_groups(db, session, args.num_dtypes) + v1_stats = insert_v1_groups(db, session) + v2_stats = insert_v2_groups(db, session, args.num_dtypes) except Exception: session.rollback() raise @@ -285,6 +284,15 @@ def main(): session.close() db.close() + # Merge and print + all_stats = defaultdict(lambda: {"records": 0, "groups": set()}) + for s in (v1_stats, v2_stats): + for key, val in s.items(): + all_stats[key]["records"] += val["records"] + all_stats[key]["groups"].update(val["groups"]) + + print("=" * 60) + _print_stats(all_stats) print("\nDone!") From f784701db2d64ab40805eae28ccfcb9b281a56cd Mon Sep 17 00:00:00 2001 From: Honglei-Qiu <1044497581@qq.com> Date: Wed, 25 Mar 2026 06:42:51 +0000 Subject: [PATCH 10/10] group by all --- sqlite/graph_net_sample_groups_insert.py | 322 +++++++++++------------ 1 file changed, 152 insertions(+), 170 deletions(-) diff --git a/sqlite/graph_net_sample_groups_insert.py b/sqlite/graph_net_sample_groups_insert.py index 3619ea6db..b9a68ab3a 100755 --- a/sqlite/graph_net_sample_groups_insert.py +++ b/sqlite/graph_net_sample_groups_insert.py @@ -1,100 +1,77 @@ import argparse import sqlite3 import uuid as uuid_module +from collections import defaultdict, namedtuple from datetime import datetime -from collections import namedtuple, defaultdict -from orm_models import ( - get_session, - GraphNetSampleGroup, -) - - -class DB: - def __init__(self, path): - self.path = path - - def connect(self): - self.conn = sqlite3.connect(self.path) - self.conn.row_factory = sqlite3.Row - self.cur = self.conn.cursor() - - def query(self, sql, params=None): - self.cur.execute(sql, params or ()) - return self.cur.fetchall() - - def close(self): - self.conn.close() +from orm_models import get_session, GraphNetSampleGroup -# ── V1 types ── +# ── Types ── BucketGroup = namedtuple( "BucketGroup", ["head_uid", "op_seq", "shapes", "sample_type", "all_uids_csv"], ) -# ── V2 types ── - -V2Candidate = namedtuple( - "V2Candidate", +Candidate = namedtuple( + "Candidate", ["uid", "sample_type", "op_seq", "shapes", "dtypes"], ) +# ── Helpers ── + + def _new_group_id(): return str(uuid_module.uuid4()) +def _merge_stats(dst, src): + for key, val in src.items(): + dst[key]["records"] += val["records"] + dst[key]["groups"].update(val["groups"]) + + def _print_stats(stats): rule_order = ["rule1", "rule2", "rule4", "rule3"] - sample_types = sorted({st for st, _ in stats.keys()}) + sample_types = sorted({st for st, _ in stats}) total_records = 0 total_groups = 0 for sample_type in sample_types: print(f"\n [{sample_type}]") - for rule_name in rule_order: - key = (sample_type, rule_name) + for rule in rule_order: + key = (sample_type, rule) if key in stats: - r = stats[key]["records"] - g = len(stats[key]["groups"]) - print(f" {rule_name}: {r} records, {g} groups") - total_records += r - total_groups += g + n_records = stats[key]["records"] + n_groups = len(stats[key]["groups"]) + print(f" {rule}: {n_records} records, {n_groups} groups") + total_records += n_records + total_groups += n_groups print(f"\n Total: {total_records} records, {total_groups} groups.") -# ═══════════════════════════════════════════════════════════════════ -# V1: Rule 1 (bucket-internal stride-5) + Rule 2 (cross-shape aggregation) -# ═══════════════════════════════════════════════════════════════════ +# ── Database Queries ── -def generate_v1_groups(bucket_groups: list[BucketGroup]): - """Rule 1: stride-5 sampling, each sampled uid gets its own group. - Rule 2: group all bucket heads that share the same op_seq. - Yields (sample_type, uid, group_id, rule_name).""" +class DB: + def __init__(self, path): + self.path = path - # Rule 1: stride-5 sampling - for bucket in bucket_groups: - members = bucket.all_uids_csv.split(",") - sampled = [uid for uid in members[::16] if uid != bucket.head_uid] - for uid in sampled: - yield bucket.sample_type, uid, _new_group_id(), "rule1" + def connect(self): + self.conn = sqlite3.connect(self.path) + self.conn.row_factory = sqlite3.Row + self.cursor = self.conn.cursor() - # Rule 2: group heads by (sample_type, op_seq) - type_op_seq_to_heads = defaultdict(list) - for bucket in bucket_groups: - type_op_seq_to_heads[(bucket.sample_type, bucket.op_seq)].append( - bucket.head_uid - ) + def query(self, sql, params=None): + self.cursor.execute(sql, params or ()) + return self.cursor.fetchall() - for (sample_type, _), heads in type_op_seq_to_heads.items(): - group_id = _new_group_id() - for uid in heads: - yield sample_type, uid, group_id, "rule2" + def close(self): + self.conn.close() -def query_v1_bucket_groups(db: DB) -> list[BucketGroup]: +def query_bucket_groups(db: DB) -> list[BucketGroup]: sql = """ SELECT sub.sample_uid, @@ -117,124 +94,129 @@ def query_v1_bucket_groups(db: DB) -> list[BucketGroup]: return [BucketGroup(*row) for row in db.query(sql)] -def insert_v1_groups(db: DB, session): - bucket_groups = query_v1_bucket_groups(db) - print(f"Bucket groups: {len(bucket_groups)}") +def query_v2_candidates(db: DB) -> list[Candidate]: + sql = """ +SELECT + s.uuid, + s.sample_type, + b.op_seq_bucket_id, + b.input_shapes_bucket_id, + b.input_dtypes_bucket_id +FROM graph_sample s +JOIN graph_net_sample_buckets b ON s.uuid = b.sample_uid +WHERE s.deleted = 0 + AND s.sample_type != 'full_graph' + AND s.uuid NOT IN ( + SELECT g.sample_uid + FROM graph_net_sample_groups g + WHERE g.group_policy = 'bucket_policy_v1' + AND g.deleted = 0 + ) +ORDER BY s.sample_type, b.op_seq_bucket_id, b.input_shapes_bucket_id, + b.input_dtypes_bucket_id, s.uuid; + """ + return [Candidate(*row) for row in db.query(sql)] - stats = defaultdict(lambda: {"records": 0, "groups": set()}) - for sample_type, uid, group_id, rule_name in generate_v1_groups(bucket_groups): - session.add( - GraphNetSampleGroup( - sample_uid=uid, - group_uid=group_id, - group_type="ai4c", - group_policy="bucket_policy_v1", - policy_version="1.0", - create_at=datetime.now(), - deleted=False, - ) - ) - stats[(sample_type, rule_name)]["records"] += 1 - stats[(sample_type, rule_name)]["groups"].add(group_id) - session.commit() - return stats +# ═══════════════════════════════════════════════════════════════════ +# V1: Rule 1 (bucket-internal stride sampling) + Rule 2 (cross-shape) +# ═══════════════════════════════════════════════════════════════════ + + +def generate_v1_groups(bucket_groups: list[BucketGroup]): + """Yields (sample_type, uid, group_id, rule_name). + + Rule 1: stride-16 sampling within each bucket, one group per sample. + Rule 2: aggregate all bucket heads sharing the same (sample_type, op_seq). + """ + # Rule 1 + for bucket in bucket_groups: + members = bucket.all_uids_csv.split(",") + for uid in members[::16]: + if uid != bucket.head_uid: + yield bucket.sample_type, uid, _new_group_id(), "rule1" + + # Rule 2 + heads_by_type_op = defaultdict(list) + for bucket in bucket_groups: + heads_by_type_op[(bucket.sample_type, bucket.op_seq)].append(bucket.head_uid) + for (sample_type, _op), heads in heads_by_type_op.items(): + gid = _new_group_id() + for uid in heads: + yield sample_type, uid, gid, "rule2" # ═══════════════════════════════════════════════════════════════════ -# V2: Rule 4 (dtype coverage) + Rule 3 (global sparse sampling) +# V2: Rule 4 (dtype coverage) + Rule 3 (sparse sampling on remainder) # ═══════════════════════════════════════════════════════════════════ -def generate_v2_groups(candidates: list[V2Candidate], num_dtypes: int): - """Rule 4 runs first to ensure full dtype coverage. - Rule 3 then sparse-samples from the remaining candidates. - Yields (sample_type, uid, group_id, rule_name).""" +def generate_v2_groups(candidates: list[Candidate], num_dtypes: int): + """Yields (sample_type, uid, group_id, rule_name). - candidates_by_op_seq = defaultdict(list) + Rule 4 (first): per (sample_type, op_seq, shape), pick up to + num_dtypes samples with distinct dtypes. + Rule 3 (second): window-based sparse sampling on the remainder, + window_size = num_dtypes * 5, pick first num_dtypes. + """ + by_type_op = defaultdict(list) for c in candidates: - candidates_by_op_seq[(c.sample_type, c.op_seq)].append(c) + by_type_op[(c.sample_type, c.op_seq)].append(c) - dtype_covered_uids = set() + covered_uids = set() - # --- Rule 4: dtype coverage (runs first) --- - for (sample_type, _), op_candidates in candidates_by_op_seq.items(): - candidates_by_shape = defaultdict(list) - for c in op_candidates: - candidates_by_shape[c.shapes].append(c) + # Rule 4: dtype coverage + for (sample_type, _op), group in by_type_op.items(): + by_shape = defaultdict(list) + for c in group: + by_shape[c.shapes].append(c) - selected_uids = [] - for shape, shape_candidates in candidates_by_shape.items(): + picked = [] + for _shape, shape_group in by_shape.items(): seen_dtypes = set() - for c in shape_candidates: + for c in shape_group: if c.dtypes not in seen_dtypes and len(seen_dtypes) < num_dtypes: seen_dtypes.add(c.dtypes) - selected_uids.append(c.uid) - dtype_covered_uids.add(c.uid) + picked.append(c.uid) + covered_uids.add(c.uid) - if selected_uids: - group_id = _new_group_id() - for uid in selected_uids: - yield sample_type, uid, group_id, "rule4" + if picked: + gid = _new_group_id() + for uid in picked: + yield sample_type, uid, gid, "rule4" - # --- Rule 3: global sparse sampling (on remaining candidates) --- + # Rule 3: sparse sampling on remainder window_size = num_dtypes * 5 - for (sample_type, _), op_candidates in candidates_by_op_seq.items(): - remaining = [c for c in op_candidates if c.uid not in dtype_covered_uids] - remaining.sort(key=lambda c: c.uid) - - selected_uids = [] - for idx, c in enumerate(remaining): - if (idx % window_size) < num_dtypes: - selected_uids.append(c.uid) - - if selected_uids: - group_id = _new_group_id() - for uid in selected_uids: - yield sample_type, uid, group_id, "rule3" - + for (sample_type, _op), group in by_type_op.items(): + remaining = sorted( + (c for c in group if c.uid not in covered_uids), + key=lambda c: c.uid, + ) + picked = [ + c.uid for i, c in enumerate(remaining) if (i % window_size) < num_dtypes + ] + if picked: + gid = _new_group_id() + for uid in picked: + yield sample_type, uid, gid, "rule3" -def query_v2_candidates(db: DB) -> list[V2Candidate]: - sql = """ -SELECT - s.uuid, - s.sample_type, - b.op_seq_bucket_id, - b.input_shapes_bucket_id, - b.input_dtypes_bucket_id -FROM graph_sample s -JOIN graph_net_sample_buckets b ON s.uuid = b.sample_uid -WHERE s.deleted = 0 - AND s.sample_type != 'full_graph' - AND s.uuid NOT IN ( - SELECT g.sample_uid - FROM graph_net_sample_groups g - WHERE g.group_policy = 'bucket_policy_v1' - AND g.deleted = 0 - ) -ORDER BY s.sample_type, b.op_seq_bucket_id, b.input_shapes_bucket_id, b.input_dtypes_bucket_id, s.uuid; - """ - return [V2Candidate(*row) for row in db.query(sql)] +# ═══════════════════════════════════════════════════════════════════ +# Insert +# ═══════════════════════════════════════════════════════════════════ -def insert_v2_groups(db: DB, session, num_dtypes: int): - candidates = query_v2_candidates(db) - print(f"V2 candidates: {len(candidates)}") +def _insert_groups(session, rows, policy): + """Consume a generator of (sample_type, uid, group_id, rule_name), + write to DB, and return per-(sample_type, rule) stats.""" stats = defaultdict(lambda: {"records": 0, "groups": set()}) - if not candidates: - print("No v2 candidates found. Skipping.") - return stats - - for sample_type, uid, group_id, rule_name in generate_v2_groups( - candidates, num_dtypes - ): + for sample_type, uid, group_id, rule_name in rows: session.add( GraphNetSampleGroup( sample_uid=uid, group_uid=group_id, group_type="ai4c", - group_policy="bucket_policy_v2", + group_policy=policy, policy_version="1.0", create_at=datetime.now(), deleted=False, @@ -242,7 +224,6 @@ def insert_v2_groups(db: DB, session, num_dtypes: int): ) stats[(sample_type, rule_name)]["records"] += 1 stats[(sample_type, rule_name)]["groups"].add(group_id) - session.commit() return stats @@ -256,27 +237,35 @@ def main(): parser = argparse.ArgumentParser( description="Generate graph_net_sample_groups (v1 + v2)" ) - parser.add_argument( - "--db_path", - type=str, - required=True, - help="Path to the SQLite database file", - ) - parser.add_argument( - "--num_dtypes", - type=int, - default=3, - help="Number of different dtypes to pick per shape in v2 (default: 3)", - ) + parser.add_argument("--db_path", type=str, required=True) + parser.add_argument("--num_dtypes", type=int, default=3) args = parser.parse_args() db = DB(args.db_path) db.connect() session = get_session(args.db_path) + all_stats = defaultdict(lambda: {"records": 0, "groups": set()}) + try: - v1_stats = insert_v1_groups(db, session) - v2_stats = insert_v2_groups(db, session, args.num_dtypes) + # V1 + buckets = query_bucket_groups(db) + print(f"Bucket groups: {len(buckets)}") + v1 = _insert_groups(session, generate_v1_groups(buckets), "bucket_policy_v1") + _merge_stats(all_stats, v1) + + # V2 + candidates = query_v2_candidates(db) + print(f"V2 candidates: {len(candidates)}") + if candidates: + v2 = _insert_groups( + session, + generate_v2_groups(candidates, args.num_dtypes), + "bucket_policy_v2", + ) + _merge_stats(all_stats, v2) + else: + print("No V2 candidates found. Skipping.") except Exception: session.rollback() raise @@ -284,13 +273,6 @@ def main(): session.close() db.close() - # Merge and print - all_stats = defaultdict(lambda: {"records": 0, "groups": set()}) - for s in (v1_stats, v2_stats): - for key, val in s.items(): - all_stats[key]["records"] += val["records"] - all_stats[key]["groups"].update(val["groups"]) - print("=" * 60) _print_stats(all_stats) print("\nDone!")