From 2fe8d81d2b65ffeb185d50414e0c7d6903b61c22 Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Fri, 27 Mar 2026 11:17:46 +0800 Subject: [PATCH] Fix comparation when outputs are no different devices. --- graph_net_bench/torch/test_compiler.py | 34 ++++++++++++++------------ 1 file changed, 18 insertions(+), 16 deletions(-) diff --git a/graph_net_bench/torch/test_compiler.py b/graph_net_bench/torch/test_compiler.py index e41d9aa42..243ccdcd1 100755 --- a/graph_net_bench/torch/test_compiler.py +++ b/graph_net_bench/torch/test_compiler.py @@ -308,22 +308,21 @@ def print_and_store_cmp(key, cmp_func, args, expected_out, compiled_out, **kwarg def compare_correctness(expected_out, compiled_out, args): - eager_dtypes = [ - ( - str(x.dtype).replace("torch.", "") - if isinstance(x, torch.Tensor) - else type(x).__name__ - ) - for x in expected_out - ] - compiled_dtypes = [ - ( - str(x.dtype).replace("torch.", "") - if isinstance(x, torch.Tensor) - else type(x).__name__ - ) - for x in compiled_out - ] + def _get_output_dtypes(outs): + return [ + ( + str(x.dtype).replace("torch.", "") + if isinstance(x, torch.Tensor) + else type(x).__name__ + ) + for x in outs + ] + + def _align_output_device(outs, device): + return [x.to(device) if x.device != device else x for x in outs] + + eager_dtypes = _get_output_dtypes(expected_out) + compiled_dtypes = _get_output_dtypes(compiled_out) # datatype check type_match = test_compiler_util.check_output_datatype( @@ -331,6 +330,9 @@ def compare_correctness(expected_out, compiled_out, args): ) if type_match: + expected_out = _align_output_device(expected_out, args.device) + compiled_out = _align_output_device(compiled_out, args.device) + test_compiler_util.check_equal( args, expected_out,