Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 74 additions & 46 deletions deployment/exporters/variance_exporter.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,33 +388,47 @@ def _torch_export_model(self):
noise = torch.randn(shape, device=self.device)
condition = torch.rand((1, hparams['hidden_size'], 15), device=self.device)

print(f'Tracing {self.pitch_backbone_class_name} backbone...')
pitch_predictor = self.model.view_as_pitch_predictor()
pitch_predictor.pitch_predictor.set_backbone(
torch.jit.trace(
pitch_predictor.pitch_predictor.backbone,
(
noise,
dummy_time,
condition

if torch.__version__.startswith('1.13.'):
Copy link

Copilot AI Mar 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The PyTorch version gate (torch.__version__.startswith('1.13.')) is duplicated in multiple blocks in this method. Consider computing a single boolean once (e.g., is_torch_113) near the top of _torch_export_model and reusing it to avoid repeating version parsing logic and to make future adjustments less error-prone.

Suggested change
if torch.__version__.startswith('1.13.'):
is_torch_113 = torch.__version__.startswith('1.13.')
if is_torch_113:

Copilot uses AI. Check for mistakes.
print(f'Tracing {self.pitch_backbone_class_name} backbone...')
pitch_predictor.pitch_predictor.set_backbone(
torch.jit.trace(
pitch_predictor.pitch_predictor.backbone,
(
noise,
dummy_time,
condition
)
)
)
)

print(f'Scripting {self.pitch_predictor_class_name}...')
pitch_predictor = torch.jit.script(
pitch_predictor,
example_inputs=[
(
condition.transpose(1, 2),
1 # p_sample branch
),
(
condition.transpose(1, 2),
dummy_steps # p_sample_plms branch
)
]
)
print(f'Scripting {self.pitch_predictor_class_name}...')
pitch_predictor = torch.jit.script(
pitch_predictor,
example_inputs=[
(
condition.transpose(1, 2),
1 # p_sample branch
),
(
condition.transpose(1, 2),
dummy_steps # p_sample_plms branch
)
]
)
else:
print(f'Wrapping {self.pitch_predictor_class_name} for trace-based export...')

class _PitchPredWrapper(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.pitch_predictor = model.pitch_predictor

def forward(self, pitch_cond, steps):
return self.pitch_predictor(pitch_cond, steps=steps)

pitch_predictor = _PitchPredWrapper(pitch_predictor)

print(f'Exporting {self.pitch_predictor_class_name}...')
torch.onnx.export(
Expand Down Expand Up @@ -535,33 +549,47 @@ def _torch_export_model(self):
condition = torch.rand((1, hparams['hidden_size'], 15), device=self.device)
step = (torch.rand((1,), device=self.device) * hparams['K_step']).long()

print(f'Tracing {self.variance_backbone_class_name} backbone...')
multi_var_predictor = self.model.view_as_variance_predictor()
multi_var_predictor.variance_predictor.set_backbone(
torch.jit.trace(
multi_var_predictor.variance_predictor.backbone,
(
noise,
step,
condition

if torch.__version__.startswith('1.13.'):
Copy link

Copilot AI Mar 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same version check duplication as above: this block repeats torch.__version__.startswith('1.13.'). To reduce maintenance overhead, reuse a single precomputed flag (e.g., is_torch_113) within _torch_export_model.

Suggested change
if torch.__version__.startswith('1.13.'):
if is_torch_113:

Copilot uses AI. Check for mistakes.
print(f'Tracing {self.variance_backbone_class_name} backbone...')
multi_var_predictor.variance_predictor.set_backbone(
torch.jit.trace(
multi_var_predictor.variance_predictor.backbone,
(
noise,
step,
condition
)
)
)
)

print(f'Scripting {self.multi_var_predictor_class_name}...')
multi_var_predictor = torch.jit.script(
multi_var_predictor,
example_inputs=[
(
condition.transpose(1, 2),
1 # p_sample branch
),
(
condition.transpose(1, 2),
dummy_steps # p_sample_plms branch
)
]
)
print(f'Scripting {self.multi_var_predictor_class_name}...')
multi_var_predictor = torch.jit.script(
multi_var_predictor,
example_inputs=[
(
condition.transpose(1, 2),
1 # p_sample branch
),
(
condition.transpose(1, 2),
dummy_steps # p_sample_plms branch
)
]
)
else:
print(f'Wrapping {self.multi_var_predictor_class_name} for trace-based export...')

class _VarPredWrapper(torch.nn.Module):
def __init__(self, model):
super().__init__()
self.variance_predictor = model.variance_predictor

def forward(self, variance_cond, steps):
return self.variance_predictor(variance_cond, steps=steps)

multi_var_predictor = _VarPredWrapper(multi_var_predictor)

print(f'Exporting {self.multi_var_predictor_class_name}...')
torch.onnx.export(
Expand Down
7 changes: 5 additions & 2 deletions scripts/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@


def check_pytorch_version():
# Warn (but proceed) if PyTorch version is not the tested 1.13.x baseline
import warnings
if torch.__version__.startswith('1.13.'):
return
raise RuntimeError('This script requires PyTorch 1.13.x. Please install the correct version.')
warnings.warn(
f'ONNX export is tested on PyTorch 1.13.x, but you have {torch.__version__}. '
f'Proceeding with trace-based fallback for variance models.'
Copy link

Copilot AI Mar 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The warning text is misleading because check_pytorch_version() runs unconditionally for all export commands (acoustic/variance/vocoder), but the message claims it is "Proceeding with trace-based fallback for variance models." Consider either (1) making the warning generic (e.g., "Export is tested on 1.13.x...") or (2) moving this version warning into the variance export command / exporter so it only appears when the fallback is actually relevant.

Suggested change
f'Proceeding with trace-based fallback for variance models.'
f'Export may not behave as expected with this PyTorch version.'

Copilot uses AI. Check for mistakes.
)


def find_exp(exp):
Expand Down
Loading