diff --git a/graph_net_bench/torch/test_compiler.py b/graph_net_bench/torch/test_compiler.py index e41d9aa42..243ccdcd1 100755 --- a/graph_net_bench/torch/test_compiler.py +++ b/graph_net_bench/torch/test_compiler.py @@ -308,22 +308,21 @@ def print_and_store_cmp(key, cmp_func, args, expected_out, compiled_out, **kwarg def compare_correctness(expected_out, compiled_out, args): - eager_dtypes = [ - ( - str(x.dtype).replace("torch.", "") - if isinstance(x, torch.Tensor) - else type(x).__name__ - ) - for x in expected_out - ] - compiled_dtypes = [ - ( - str(x.dtype).replace("torch.", "") - if isinstance(x, torch.Tensor) - else type(x).__name__ - ) - for x in compiled_out - ] + def _get_output_dtypes(outs): + return [ + ( + str(x.dtype).replace("torch.", "") + if isinstance(x, torch.Tensor) + else type(x).__name__ + ) + for x in outs + ] + + def _align_output_device(outs, device): + return [x.to(device) if x.device != device else x for x in outs] + + eager_dtypes = _get_output_dtypes(expected_out) + compiled_dtypes = _get_output_dtypes(compiled_out) # datatype check type_match = test_compiler_util.check_output_datatype( @@ -331,6 +330,9 @@ def compare_correctness(expected_out, compiled_out, args): ) if type_match: + expected_out = _align_output_device(expected_out, args.device) + compiled_out = _align_output_device(compiled_out, args.device) + test_compiler_util.check_equal( args, expected_out,