Skip to content

ENH: testing.lazy_xp_function: torch.compile support#668

Draft
lucascolley wants to merge 4 commits intodata-apis:mainfrom
lucascolley:torch-autojit
Draft

ENH: testing.lazy_xp_function: torch.compile support#668
lucascolley wants to merge 4 commits intodata-apis:mainfrom
lucascolley:torch-autojit

Conversation

@lucascolley
Copy link
Copy Markdown
Member

@lucascolley lucascolley commented Apr 3, 2026

Closes gh-664

@rgommers this was easier than I expected. Is there someone familiar with Dynamo that we could ping? A few things that came up:

  • The following deprecation warning with a confusing message given that we are already using torch.compile:
E           torch._dynamo.exc.BackendCompilerFailed: backend='inductor' raised:
E           DeprecationWarning: `torch.jit.script_method` is not supported in Python 3.14+ and may break. Please switch to `torch.compile` or `torch.export`.
  • Some failing tests here with:
E                       torch._dynamo.exc.FailOnRecompileLimitHit: recompile_limit reached with fullgraph=True. Excessive recompilations can degrade performance due to the compilation overhead of each recompilation. To monitor recompilations, enable TORCH_LOGS=recompiles. If recompilations are expected, consider increasing torch._dynamo.config.cache_size_limit to an appropriate value.

(Draft PR as this will need docs updates before merge.)

Comment on lines +534 to 536
if jit_library is JitLibrary.jax and isinstance(obj, Iterator):
self._obj = list(obj)
self._is_iter = True
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

claude reckoned that, unlike JAX, we needn't treat iterables with this special case for torch.compile

Comment on lines +424 to +436
wrapped = autojit(func, JitLibrary.jax)
# If we're dealing with a staticmethod or classmethod, make
# sure things stay that way.
if isinstance(attr, staticmethod):
wrapped = staticmethod(wrapped)
elif isinstance(attr, classmethod):
wrapped = classmethod(wrapped)
temp_setattr(target, name, wrapped)

elif is_torch_namespace(xp):
for target, name, attr, func, tags in iter_tagged():
if tags["torch_compile"]:
wrapped = autojit(func, JitLibrary.torch)
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(minor: could reduce some LoC perhaps)

@lucascolley
Copy link
Copy Markdown
Member Author

@ev-br FYI

@lucascolley lucascolley requested a review from rgommers April 3, 2026 12:03
Copy link
Copy Markdown
Member Author

@lucascolley lucascolley left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think the warnings about to flood CI should be fixed by data-apis/array-api-compat#411 (comment).

Comment on lines +90 to +91
[tool.pixi.pypi-dependencies]
array-api-compat = { git = "https://github.com/data-apis/array-api-compat", branch = "lucascolley-patch-1" }
Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

let's see if data-apis/array-api-compat#413 is all we need to change in array-api-compat...

@lucascolley
Copy link
Copy Markdown
Member Author

ready for review

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

ENH: testing.lazy_xp_function: torch.compile support?

1 participant