Conversation
|
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update. |
This function lists all the variants that are available on the current system.
c53cde0 to
b8a27cb
Compare
|
|
||
| version: Version | None | ||
| _VARIANT_REGEX: ClassVar[re.Pattern] = re.compile( | ||
| r"torch(\d+?)(\d+)(?:-(cxx11|cxx98))?" |
There was a problem hiding this comment.
Prepare for variants without an ABI tag (see extra variants in tests).
There was a problem hiding this comment.
Here I think one liner explainer comment would be nice. Bits around (cxx11|cxx98) aren't particularly clear I think.
|
|
||
| version: Version | None | ||
| _VARIANT_REGEX: ClassVar[re.Pattern] = re.compile( | ||
| r"torch(\d+?)(\d+)(?:-(cxx11|cxx98))?" |
There was a problem hiding this comment.
Here I think one liner explainer comment would be nice. Bits around (cxx11|cxx98) aren't particularly clear I think.
| cxx11_abi = torch.compiled_with_cxx11_abi() | ||
| return [ | ||
| Torch(version=torch_version, cxx11_abi=cxx11_abi), | ||
| Torch(version=torch_version, cxx11_abi=None), |
There was a problem hiding this comment.
Curious why include Torch(version=torch_version, cxx11_abi=None) here as well?
| def possible_variants() -> list["ArchVariant"]: | ||
| frameworks: list[Torch | TvmFfi] = ( | ||
| Torch.possible_variants() + TvmFfi.possible_variants() | ||
| ) | ||
| archs = Arch.possible_variants() | ||
| return [ | ||
| ArchVariant(framework=fw, arch=arch) | ||
| for fw, arch in itertools.product(frameworks, archs) | ||
| ] |
There was a problem hiding this comment.
This is so neat! As declarative as it gets.
| if parts[0] == "torch": | ||
| # noarch: e.g. "torch-cpu" | ||
| return NoarchVariant( | ||
| framework=TorchNoarch(), arch=Noarch.parse("-".join(parts[1:])) | ||
| ) | ||
| elif parts[0].startswith("torch"): | ||
| if len(parts) >= 2 and parts[1] in ("cxx11", "cxx98"): | ||
| framework_str = f"{parts[0]}-{parts[1]}" | ||
| arch_parts = parts[2:] | ||
| else: | ||
| framework_str = parts[0] | ||
| arch_parts = parts[1:] | ||
| return ArchVariant( | ||
| framework=Torch.parse(framework_str), arch=Arch.parse(arch_parts) | ||
| ) | ||
| elif parts[0] == "tvm" and len(parts) >= 2 and parts[1].startswith("ffi"): | ||
| return ArchVariant( | ||
| framework=TvmFfi.parse(f"tvm-{parts[1]}"), arch=Arch.parse(parts[2:]) | ||
| ) |
There was a problem hiding this comment.
This conditional logic feels a bit rigid, but I guess it's purely heuristics-driven and isn't meant to change? If it changes then there's something wrong.
| framework = TvmFfi.parse(f"tvm-{parts[1]}") | ||
| arch = Arch.parse(parts[2:]) | ||
| else: | ||
| raise ValueError(f"Unknown framework in variant string: {variant_str!r}") |
There was a problem hiding this comment.
So, we're removing parse() from the variant dataclasses and keeping a central parse_variant() utility?
| """Parse a variant string into an ArchVariant or NoarchVariant.""" | ||
| parts = variant_str.split("-") | ||
|
|
||
| if parts[0] == "torch": |
There was a problem hiding this comment.
Should there be a constraint on the length of parts in case of NoarchVariant because ArchVariant will also have parts[0] == "torch"?
| system_variants, | ||
| ) | ||
|
|
||
| VARIANT_STRINGS = [ |
There was a problem hiding this comment.
It could be nice to programatically construct this list against a predefined set of framework (+ versions), platform, arch, backend, etc.
This function lists all the variants that are available on the current system.