Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion graph_net/paddle/backend/cinn_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ def __call__(self, model, input_spec=None):
full_graph=True,
)
compiled_model.eval()
program = compiled_model.forward.concrete_program.main_program
return compiled_model

def synchronize(self):
Expand Down
10 changes: 0 additions & 10 deletions graph_net/paddle/check_redundant_incrementally.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,6 @@
from . import utils
import argparse
import importlib.util
import inspect
from pathlib import Path
from typing import Type, Any
import sys
import os
import os.path
from dataclasses import dataclass
from contextlib import contextmanager
import time
import glob


def get_recursively_model_pathes(root_dir):
Expand Down
10 changes: 0 additions & 10 deletions graph_net/paddle/remove_redundant_incrementally.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,6 @@
from . import utils
import argparse
import importlib.util
import inspect
from pathlib import Path
from typing import Type, Any
import sys
import os
import os.path
from dataclasses import dataclass
from contextlib import contextmanager
import time
import glob
import shutil


Expand Down
7 changes: 0 additions & 7 deletions graph_net/paddle/validate.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,8 @@
from . import utils
import argparse
import importlib.util
import inspect
from pathlib import Path
from typing import Type, Any
import sys
import hashlib
from contextlib import contextmanager
from collections import ChainMap
import numpy as np
import graph_net
import os
import ast
Expand Down Expand Up @@ -36,7 +30,6 @@ def _extract_forward_source(model_path, class_name):
source = f.read()

tree = ast.parse(source)
forward_code = None

for node in tree.body:
if isinstance(node, ast.ClassDef) and node.name == class_name:
Expand Down
1 change: 0 additions & 1 deletion graph_net/test/bert_model_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import torch
from transformers import AutoModel, AutoTokenizer
import graph_net.torch
import os


def get_model_name():
Expand Down
11 changes: 0 additions & 11 deletions graph_net/torch/check_redundant_incrementally.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,6 @@
from . import utils
import argparse
import importlib.util
import inspect
import torch
from pathlib import Path
from typing import Type, Any
import sys
import os
import os.path
from dataclasses import dataclass
from contextlib import contextmanager
import time
import glob


def get_recursively_model_pathes(root_dir):
Expand Down
4 changes: 3 additions & 1 deletion graph_net/torch/dim_gen_passes/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from graph_net.torch.dim_gen_passes.pass_base import DimensionGeneralizationPass
from graph_net.torch.dim_gen_passes.pass_base import (
DimensionGeneralizationPass as DimensionGeneralizationPass,
)
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import torch
import torch.fx as fx
from graph_net.torch.dim_gen_passes import DimensionGeneralizationPass
import os
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import torch
import torch.fx as fx
from graph_net.torch.dim_gen_passes import DimensionGeneralizationPass
from collections import namedtuple
Expand Down
3 changes: 0 additions & 3 deletions graph_net/torch/dim_gen_passes/pass_base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,4 @@
import torch
import torch.fx as fx
import inspect
import os


class DimensionGeneralizationPass:
Expand Down
2 changes: 1 addition & 1 deletion graph_net/torch/dim_gen_passes/pass_mgr.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import os


def get_dim_gen_pass(pass_name) -> "DimensionGeneralizationPass":
def get_dim_gen_pass(pass_name) -> "DimensionGeneralizationPass": # noqa: F821
import graph_net.torch.dim_gen_passes as dgpass

py_module = load_module(f"{os.path.dirname(dgpass.__file__)}/{pass_name}.py")
Expand Down
3 changes: 1 addition & 2 deletions graph_net/torch/dump_graph_hash.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import tempfile
import sys
import contextlib
import graph_net


@contextlib.contextmanager
Expand All @@ -17,7 +16,7 @@ def temp_workspace():

def main(args):
model_path = args.model_path
with temp_workspace() as tmp_dir_name:
with temp_workspace():
print("dump-graph-hash ...")
extract_name = "temp"
cmd = f"{sys.executable} -m graph_net.torch.single_device_runner --model-path {model_path} --enable-extract True --extract-name {extract_name} --dump-graph-hash-key"
Expand Down
11 changes: 0 additions & 11 deletions graph_net/torch/remove_redundant_incrementally.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,6 @@
from . import utils
import argparse
import importlib.util
import inspect
import torch
from pathlib import Path
from typing import Type, Any
import sys
import os
import os.path
from dataclasses import dataclass
from contextlib import contextmanager
import time
import glob
import shutil


Expand Down
1 change: 0 additions & 1 deletion graph_net/torch/rp_expr/longest_rp_expr_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from graph_net.torch.rp_expr.rp_expr_parser import RpExprParser
from graph_net.torch.rp_expr.rp_expr import PrimitiveId, LetsListTokenRpExpr
import numpy as np
import sys


class LongestRpExprParser:
Expand Down
3 changes: 1 addition & 2 deletions graph_net/torch/rp_expr/rp_expr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import numpy as np
import torch
from collections import defaultdict
import functools

PrimitiveId = t.TypeVar("PrimitiveId")

Expand Down Expand Up @@ -428,7 +427,7 @@ def ValueToString(token_id):
]
)
yield from [
f"def main():",
"def main():",
*[
f" {SymbolToString(int(t[0]))}(){end_of_line}"
for t in self.body_rp_expr
Expand Down
2 changes: 0 additions & 2 deletions graph_net/torch/rp_expr/rp_expr_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,7 @@
from graph_net.torch.rp_expr.rp_expr import Tokenize, PrimitiveId, LetsListTokenRpExpr
from graph_net.torch.rp_expr.rp_expr_passes import (
FlattenTokenListPass,
FoldTokensPass,
RecursiveFoldTokensPass,
FoldIfTokenIdGreatEqualPass,
UnflattenAndSubThresholdPass,
)

Expand Down
6 changes: 1 addition & 5 deletions graph_net/torch/rp_expr/rp_expr_passes.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
from dataclasses import dataclass
import typing as t
import numpy as np
import re
import itertools
import torch
import torch.nn.functional as F
import math
from graph_net.torch.rp_expr.rp_expr import (
TokenIdAllocator,
NaiveTokenListRpExpr,
Expand All @@ -14,7 +11,6 @@
LetsTokenRpExpr,
LetsListTokenRpExpr,
)
import itertools
import sys


Expand Down Expand Up @@ -167,7 +163,7 @@ def GetWeight():
conv_weight = torch.cat(
[GetWeight() for _ in range(self.random_feature_size)], dim=1
)
conv = lambda input: F.conv1d(input, conv_weight, padding=0)
conv = lambda input: F.conv1d(input, conv_weight, padding=0) # noqa: E731
return conv, windows_size

def GetDisjoint(self, gap, indexes):
Expand Down
1 change: 0 additions & 1 deletion graph_net/torch/rp_expr/rp_expr_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
from graph_net.torch.rp_expr.rp_expr import LetsListTokenRpExpr
from graph_net.torch.rp_expr.nested_range import Range, Tree
from collections import defaultdict
from dataclasses import dataclass


def MakeNestedIndexRangeFromLetsListTokenRpExpr(
Expand Down
3 changes: 1 addition & 2 deletions graph_net/torch/shape_prop.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import torch
from typing import Union, Callable
from torch.fx.passes.shape_prop import ShapeProp
import inspect
import os
Expand Down Expand Up @@ -33,4 +32,4 @@ def forward(self, *args, **kwargs):
inputs = [
kwargs[name] for name in inspect.signature(self.module.forward).parameters
]
propagated_model = ShapeProp(traced_model).propagate(*inputs)
ShapeProp(traced_model).propagate(*inputs)
5 changes: 1 addition & 4 deletions graph_net/torch/single_device_runner.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
from . import utils
import argparse
import importlib.util
import inspect
import torch
import logging
from pathlib import Path
from typing import Type, Any
import sys
from typing import Type
from graph_net.torch.extractor import extract
import hashlib
import json
Expand Down
9 changes: 2 additions & 7 deletions graph_net/torch/typical_sequence_split_points.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
import argparse
import json
import os
from pathlib import Path
from typing import Any, Dict, List
import torch
import torch.nn as nn
from typing import Dict, List

from graph_net.torch.rp_expr.rp_expr_parser import RpExprParser
from graph_net.torch.rp_expr.rp_expr_util import (
MakeNestedIndexRangeFromLetsListTokenRpExpr,
)
from graph_net.torch.fx_graph_module_util import get_torch_module_and_inputs
from graph_net.torch.fx_graph_parse_util import parse_sole_graph_module_without_varify


class SplitAnalyzer:
Expand Down Expand Up @@ -365,4 +360,4 @@ def _get_model_subgraph_ranges(
help="maximum number of sequence operators",
)
args = parser.parse_args()
main(args)
main(args)
5 changes: 4 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,7 @@ include = [
"LICENSE",
]


[tool.ruff.lint.per-file-ignores]
# F821: numpy string annotations like np.ndarray["N", np.int64] trigger false positives
"graph_net/torch/rp_expr/rp_expr.py" = ["F821"]
"graph_net/torch/rp_expr/rp_expr_passes.py" = ["F821"]
Loading