@@ -76,7 +76,7 @@ def joint_matrix_diagonalization(
7676 quadratic convergence rate.
7777
7878 Args:
79- X (_type_): n matrices, organized in a single tensor of dimension (k, k, n).
79+ matrices_tensor (_type_): n matrices, organized in a single tensor of dimension (k, k, n).
8080 max_n_iter (int, optional): Maximum iteration number. Defaults to 50.
8181 threshold (float, optional): Threshold for decrease in deviation indicating convergence. Defaults to 1e-8.
8282 verbose (bool, optional): Output progress information during diagonalization. Defaults to False.
@@ -139,30 +139,35 @@ def joint_matrix_diagonalization(
139139 )
140140
141141 # Inverse of Sk on left side
142- pvec = tl .copy (matrices_tensor [p , :, :])
143- X = tl .index_update (
142+ prev_vec = tl .copy (matrices_tensor [p , :, :])
143+ matrices_tensor = tl .index_update (
144144 matrices_tensor ,
145145 tl .index [p , :, :],
146146 matrices_tensor [p , :, :] * tl .cosh (yk )
147147 - matrices_tensor [q , :, :] * tl .sinh (yk ),
148148 )
149- X = tl .index_update (
150- X , tl .index [q , :, :], - pvec * tl .sinh (yk ) + X [q , :, :] * tl .cosh (yk )
149+ matrices_tensor = tl .index_update (
150+ matrices_tensor ,
151+ tl .index [q , :, :],
152+ - prev_vec * tl .sinh (yk ) + matrices_tensor [q , :, :] * tl .cosh (yk ),
151153 )
152154
153155 # Sk on right side
154- pvec = tl .copy (X [:, p , :])
155- X = tl .index_update (
156- X ,
156+ prev_vec = tl .copy (matrices_tensor [:, p , :])
157+ matrices_tensor = tl .index_update (
158+ matrices_tensor ,
157159 tl .index [:, p , :],
158- X [:, p , :] * tl .cosh (yk ) + X [:, q , :] * tl .sinh (yk ),
160+ matrices_tensor [:, p , :] * tl .cosh (yk )
161+ + matrices_tensor [:, q , :] * tl .sinh (yk ),
159162 )
160- X = tl .index_update (
161- X , tl .index [:, q , :], pvec * tl .sinh (yk ) + X [:, q , :] * tl .cosh (yk )
163+ matrices_tensor = tl .index_update (
164+ matrices_tensor ,
165+ tl .index [:, q , :],
166+ prev_vec * tl .sinh (yk ) + matrices_tensor [:, q , :] * tl .cosh (yk ),
162167 )
163168
164169 # Update transform_P
165- pvec = tl .copy (transform_P [:, p ])
170+ prev_vec = tl .copy (transform_P [:, p ])
166171 transform_P = tl .index_update (
167172 transform_P ,
168173 tl .index [:, p ],
@@ -171,11 +176,11 @@ def joint_matrix_diagonalization(
171176 transform_P = tl .index_update (
172177 transform_P ,
173178 tl .index [:, q ],
174- pvec * tl .sinh (yk ) + transform_P [:, q ] * tl .cosh (yk ),
179+ prev_vec * tl .sinh (yk ) + transform_P [:, q ] * tl .cosh (yk ),
175180 )
176181
177182 # Defines array of off-diagonal element differences
178- xi_ = - X [q , p , :] - X [p , q , :]
183+ xi_ = - matrices_tensor [q , p , :] - matrices_tensor [p , q , :]
179184
180185 # Compute rotation angle
181186 Esum = 2 * tl .dot (xi_ , d_ )
@@ -194,33 +199,35 @@ def joint_matrix_diagonalization(
194199 raise RuntimeError ("joint_matrix_diagonalization: No solution found." )
195200
196201 # Given's rotation, this will minimize norm of off-diagonal elements only
197- pvec = tl .copy (X [p , :, :])
198- X = tl .index_update (
199- X ,
202+ prev_vec = tl .copy (matrices_tensor [p , :, :])
203+ matrices_tensor = tl .index_update (
204+ matrices_tensor ,
200205 tl .index [p , :, :],
201- X [p , :, :] * tl .cos (theta_k ) - X [q , :, :] * tl .sin (theta_k ),
206+ matrices_tensor [p , :, :] * tl .cos (theta_k )
207+ - matrices_tensor [q , :, :] * tl .sin (theta_k ),
202208 )
203- X = tl .index_update (
204- X ,
209+ matrices_tensor = tl .index_update (
210+ matrices_tensor ,
205211 tl .index [q , :, :],
206- pvec * tl .sin (theta_k ) + X [q , :, :] * tl .cos (theta_k ),
212+ prev_vec * tl .sin (theta_k ) + matrices_tensor [q , :, :] * tl .cos (theta_k ),
207213 )
208214
209215 # Right side rotation
210- pvec = tl .copy (X [:, p , :])
211- X = tl .index_update (
212- X ,
216+ prev_vec = tl .copy (matrices_tensor [:, p , :])
217+ matrices_tensor = tl .index_update (
218+ matrices_tensor ,
213219 tl .index [:, p , :],
214- X [:, p , :] * tl .cos (theta_k ) - X [:, q , :] * tl .sin (theta_k ),
220+ matrices_tensor [:, p , :] * tl .cos (theta_k )
221+ - matrices_tensor [:, q , :] * tl .sin (theta_k ),
215222 )
216- X = tl .index_update (
217- X ,
223+ matrices_tensor = tl .index_update (
224+ matrices_tensor ,
218225 tl .index [:, q , :],
219- pvec * tl .sin (theta_k ) + X [:, q , :] * tl .cos (theta_k ),
226+ prev_vec * tl .sin (theta_k ) + matrices_tensor [:, q , :] * tl .cos (theta_k ),
220227 )
221228
222229 # Update transform_P
223- pvec = tl .copy (transform_P [:, p ])
230+ prev_vec = tl .copy (transform_P [:, p ])
224231 transform_P = tl .index_update (
225232 transform_P ,
226233 tl .index [:, p ],
@@ -230,7 +237,7 @@ def joint_matrix_diagonalization(
230237 transform_P = tl .index_update (
231238 transform_P ,
232239 tl .index [:, q ],
233- pvec * tl .sin (theta_k ) + transform_P [:, q ] * tl .cos (theta_k ),
240+ prev_vec * tl .sin (theta_k ) + transform_P [:, q ] * tl .cos (theta_k ),
234241 )
235242
236243 # Update deviation from normality
0 commit comments