Skip to content

Commit dba3416

Browse files
Aaron MeyerAaron Meyer
authored andcommitted
Fix bug from renaming variables
1 parent ef6830d commit dba3416

File tree

1 file changed

+37
-30
lines changed

1 file changed

+37
-30
lines changed

tensorly/utils/jointdiag.py

Lines changed: 37 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ def joint_matrix_diagonalization(
7676
quadratic convergence rate.
7777
7878
Args:
79-
X (_type_): n matrices, organized in a single tensor of dimension (k, k, n).
79+
matrices_tensor (_type_): n matrices, organized in a single tensor of dimension (k, k, n).
8080
max_n_iter (int, optional): Maximum iteration number. Defaults to 50.
8181
threshold (float, optional): Threshold for decrease in deviation indicating convergence. Defaults to 1e-8.
8282
verbose (bool, optional): Output progress information during diagonalization. Defaults to False.
@@ -139,30 +139,35 @@ def joint_matrix_diagonalization(
139139
)
140140

141141
# Inverse of Sk on left side
142-
pvec = tl.copy(matrices_tensor[p, :, :])
143-
X = tl.index_update(
142+
prev_vec = tl.copy(matrices_tensor[p, :, :])
143+
matrices_tensor = tl.index_update(
144144
matrices_tensor,
145145
tl.index[p, :, :],
146146
matrices_tensor[p, :, :] * tl.cosh(yk)
147147
- matrices_tensor[q, :, :] * tl.sinh(yk),
148148
)
149-
X = tl.index_update(
150-
X, tl.index[q, :, :], -pvec * tl.sinh(yk) + X[q, :, :] * tl.cosh(yk)
149+
matrices_tensor = tl.index_update(
150+
matrices_tensor,
151+
tl.index[q, :, :],
152+
-prev_vec * tl.sinh(yk) + matrices_tensor[q, :, :] * tl.cosh(yk),
151153
)
152154

153155
# Sk on right side
154-
pvec = tl.copy(X[:, p, :])
155-
X = tl.index_update(
156-
X,
156+
prev_vec = tl.copy(matrices_tensor[:, p, :])
157+
matrices_tensor = tl.index_update(
158+
matrices_tensor,
157159
tl.index[:, p, :],
158-
X[:, p, :] * tl.cosh(yk) + X[:, q, :] * tl.sinh(yk),
160+
matrices_tensor[:, p, :] * tl.cosh(yk)
161+
+ matrices_tensor[:, q, :] * tl.sinh(yk),
159162
)
160-
X = tl.index_update(
161-
X, tl.index[:, q, :], pvec * tl.sinh(yk) + X[:, q, :] * tl.cosh(yk)
163+
matrices_tensor = tl.index_update(
164+
matrices_tensor,
165+
tl.index[:, q, :],
166+
prev_vec * tl.sinh(yk) + matrices_tensor[:, q, :] * tl.cosh(yk),
162167
)
163168

164169
# Update transform_P
165-
pvec = tl.copy(transform_P[:, p])
170+
prev_vec = tl.copy(transform_P[:, p])
166171
transform_P = tl.index_update(
167172
transform_P,
168173
tl.index[:, p],
@@ -171,11 +176,11 @@ def joint_matrix_diagonalization(
171176
transform_P = tl.index_update(
172177
transform_P,
173178
tl.index[:, q],
174-
pvec * tl.sinh(yk) + transform_P[:, q] * tl.cosh(yk),
179+
prev_vec * tl.sinh(yk) + transform_P[:, q] * tl.cosh(yk),
175180
)
176181

177182
# Defines array of off-diagonal element differences
178-
xi_ = -X[q, p, :] - X[p, q, :]
183+
xi_ = -matrices_tensor[q, p, :] - matrices_tensor[p, q, :]
179184

180185
# Compute rotation angle
181186
Esum = 2 * tl.dot(xi_, d_)
@@ -194,33 +199,35 @@ def joint_matrix_diagonalization(
194199
raise RuntimeError("joint_matrix_diagonalization: No solution found.")
195200

196201
# Given's rotation, this will minimize norm of off-diagonal elements only
197-
pvec = tl.copy(X[p, :, :])
198-
X = tl.index_update(
199-
X,
202+
prev_vec = tl.copy(matrices_tensor[p, :, :])
203+
matrices_tensor = tl.index_update(
204+
matrices_tensor,
200205
tl.index[p, :, :],
201-
X[p, :, :] * tl.cos(theta_k) - X[q, :, :] * tl.sin(theta_k),
206+
matrices_tensor[p, :, :] * tl.cos(theta_k)
207+
- matrices_tensor[q, :, :] * tl.sin(theta_k),
202208
)
203-
X = tl.index_update(
204-
X,
209+
matrices_tensor = tl.index_update(
210+
matrices_tensor,
205211
tl.index[q, :, :],
206-
pvec * tl.sin(theta_k) + X[q, :, :] * tl.cos(theta_k),
212+
prev_vec * tl.sin(theta_k) + matrices_tensor[q, :, :] * tl.cos(theta_k),
207213
)
208214

209215
# Right side rotation
210-
pvec = tl.copy(X[:, p, :])
211-
X = tl.index_update(
212-
X,
216+
prev_vec = tl.copy(matrices_tensor[:, p, :])
217+
matrices_tensor = tl.index_update(
218+
matrices_tensor,
213219
tl.index[:, p, :],
214-
X[:, p, :] * tl.cos(theta_k) - X[:, q, :] * tl.sin(theta_k),
220+
matrices_tensor[:, p, :] * tl.cos(theta_k)
221+
- matrices_tensor[:, q, :] * tl.sin(theta_k),
215222
)
216-
X = tl.index_update(
217-
X,
223+
matrices_tensor = tl.index_update(
224+
matrices_tensor,
218225
tl.index[:, q, :],
219-
pvec * tl.sin(theta_k) + X[:, q, :] * tl.cos(theta_k),
226+
prev_vec * tl.sin(theta_k) + matrices_tensor[:, q, :] * tl.cos(theta_k),
220227
)
221228

222229
# Update transform_P
223-
pvec = tl.copy(transform_P[:, p])
230+
prev_vec = tl.copy(transform_P[:, p])
224231
transform_P = tl.index_update(
225232
transform_P,
226233
tl.index[:, p],
@@ -230,7 +237,7 @@ def joint_matrix_diagonalization(
230237
transform_P = tl.index_update(
231238
transform_P,
232239
tl.index[:, q],
233-
pvec * tl.sin(theta_k) + transform_P[:, q] * tl.cos(theta_k),
240+
prev_vec * tl.sin(theta_k) + transform_P[:, q] * tl.cos(theta_k),
234241
)
235242

236243
# Update deviation from normality

0 commit comments

Comments
 (0)