diff --git a/ctgan/data_transformer.py b/ctgan/data_transformer.py index c1e136b5..b78b2aca 100644 --- a/ctgan/data_transformer.py +++ b/ctgan/data_transformer.py @@ -189,7 +189,9 @@ def transform(self, raw_data): def _inverse_transform_continuous(self, column_transform_info, column_data, sigmas, st): gm = column_transform_info.transform - data = pd.DataFrame(column_data[:, :2], columns=list(gm.get_output_sdtypes())) + cols_names = list(gm.get_output_sdtypes()) + n_cols = len(cols_names) + data = pd.DataFrame(column_data[:, :n_cols], columns=cols_names) data[data.columns[1]] = np.argmax(column_data[:, 1:], axis=1) if sigmas is not None: selected_normalized_value = np.random.normal(data.iloc[:, 0], sigmas[st])