Support variance ONNX export on PyTorch 2.x#293
Support variance ONNX export on PyTorch 2.x#293IDExpensive-One wants to merge 1 commit intoopenvpi:mainfrom
Conversation
The variance model export previously required PyTorch 1.13.x because it used torch.jit.script on models returned by view_as_*_predictor(). This fails on PyTorch 2.x with "Unsupported value kind: Tensor" as TorchScript tries to compile all methods on the class, including those referencing deleted attributes. This adds a fallback path for PyTorch 2.x that uses lightweight wrapper modules with trace-based torch.onnx.export instead of torch.jit.script. The original script-based path is preserved for PyTorch 1.13.x. The version check in export.py is relaxed from a hard error to a warning.
There was a problem hiding this comment.
Pull request overview
Enables ONNX export of the variance (and pitch) predictor models on PyTorch 2.x by avoiding TorchScript scripting of view_as_*_predictor() outputs, while keeping the existing TorchScript-based export flow for PyTorch 1.13.x.
Changes:
- Relax PyTorch 1.13.x hard requirement in the CLI export script to a warning.
- Add a PyTorch 2.x path in the variance exporter that wraps predictor submodules for trace-based ONNX export instead of
torch.jit.script. - Preserve the PyTorch 1.13.x path that traces the backbone and scripts the predictor before exporting.
Reviewed changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| scripts/export.py | Replaces the strict PyTorch 1.13.x requirement with a warning to allow running exports on newer versions. |
| deployment/exporters/variance_exporter.py | Adds a PyTorch-version-dependent export path: TorchScript scripting on 1.13.x, wrapper-based trace export otherwise. |
💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.
| raise RuntimeError('This script requires PyTorch 1.13.x. Please install the correct version.') | ||
| warnings.warn( | ||
| f'ONNX export is tested on PyTorch 1.13.x, but you have {torch.__version__}. ' | ||
| f'Proceeding with trace-based fallback for variance models.' |
There was a problem hiding this comment.
The warning text is misleading because check_pytorch_version() runs unconditionally for all export commands (acoustic/variance/vocoder), but the message claims it is "Proceeding with trace-based fallback for variance models." Consider either (1) making the warning generic (e.g., "Export is tested on 1.13.x...") or (2) moving this version warning into the variance export command / exporter so it only appears when the fallback is actually relevant.
| f'Proceeding with trace-based fallback for variance models.' | |
| f'Export may not behave as expected with this PyTorch version.' |
| dummy_time, | ||
| condition | ||
|
|
||
| if torch.__version__.startswith('1.13.'): |
There was a problem hiding this comment.
The PyTorch version gate (torch.__version__.startswith('1.13.')) is duplicated in multiple blocks in this method. Consider computing a single boolean once (e.g., is_torch_113) near the top of _torch_export_model and reusing it to avoid repeating version parsing logic and to make future adjustments less error-prone.
| if torch.__version__.startswith('1.13.'): | |
| is_torch_113 = torch.__version__.startswith('1.13.') | |
| if is_torch_113: |
| step, | ||
| condition | ||
|
|
||
| if torch.__version__.startswith('1.13.'): |
There was a problem hiding this comment.
Same version check duplication as above: this block repeats torch.__version__.startswith('1.13.'). To reduce maintenance overhead, reuse a single precomputed flag (e.g., is_torch_113) within _torch_export_model.
| if torch.__version__.startswith('1.13.'): | |
| if is_torch_113: |
Summary
The variance model ONNX export currently requires PyTorch 1.13.x because
torch.jit.scriptis used on models returned byview_as_*_predictor(). On PyTorch 2.x, this fails withRuntimeError: Unsupported value kind: Tensorbecause TorchScript attempts to compile all methods on the class, including those that reference attributes removed bydeepcopy+delin theview_as_*methods.This PR:
torch.onnx.exporttorch.jit.scriptpath for PyTorch 1.13.xexport.pyto a warningTested on