diff --git a/internal/oraclehandlers/patching.go b/internal/oraclehandlers/patching.go index f843fe8..27cb7ef 100644 --- a/internal/oraclehandlers/patching.go +++ b/internal/oraclehandlers/patching.go @@ -18,21 +18,160 @@ package oraclehandlers import ( "context" + "fmt" + "os" + "path/filepath" + "strings" + "github.com/GoogleCloudPlatform/workloadagentplatform/sharedlibraries/commandlineexecutor" "github.com/GoogleCloudPlatform/workloadagentplatform/sharedlibraries/gce/metadataserver" "github.com/GoogleCloudPlatform/workloadagentplatform/sharedlibraries/log" + + codepb "google.golang.org/genproto/googleapis/rpc/code" gpb "github.com/GoogleCloudPlatform/workloadagentplatform/sharedprotos/guestactions" ) +var ( + osStat = os.Stat + osReadFile = os.ReadFile + osWriteFile = os.WriteFile +) + +type startupMechanism int + +const ( + startupUnknown startupMechanism = iota + startupOracleRestart + startupOratab + startupSystemdFree +) + // DisableAutostart implements the oracle_disable_autostart guest action. func DisableAutostart(ctx context.Context, command *gpb.Command, cloudProperties *metadataserver.CloudProperties) *gpb.CommandResult { - log.CtxLogger(ctx).Info("oracle_disable_autostart handler called") - // TODO: Implement oracle_disable_autostart handler. - return &gpb.CommandResult{ - Command: command, - ExitCode: 1, - Stdout: "oracle_disable_autostart not implemented.", + params := command.GetAgentCommand().GetParameters() + logger := log.CtxLogger(ctx) + if result := validateParams(ctx, logger, command, params); result != nil { + return result + } + logger = logger.With("oracle_sid", params["oracle_sid"], "oracle_home", params["oracle_home"], "oracle_user", params["oracle_user"]) + logger.Info("oracle_disable_autostart handler called") + + if err := disableAutostart(ctx, params); err != nil { + logger.Warnw("DisableAutostart failed", "error", err) + return commandResult(ctx, logger, command, "", "", codepb.Code_INTERNAL, err.Error(), err) + } + + return commandResult(ctx, logger, command, "Autostart disabled successfully", "", codepb.Code_OK, "Autostart disabled successfully", nil) +} + +func disableAutostart(ctx context.Context, params map[string]string) error { + oracleSID := params["oracle_sid"] + oracleHome := params["oracle_home"] + oracleUser := params["oracle_user"] + dbUniqueName := params["db_unique_name"] + + mechanism, err := detectStartupMechanism(ctx) + if err != nil { + return err + } + + switch mechanism { + case startupOracleRestart: + srvctlPath := filepath.Join(oracleHome, "bin", "srvctl") + disableRes := executeCommand(ctx, commandlineexecutor.Params{ + Executable: srvctlPath, + Args: []string{"disable", "database", "-d", dbUniqueName}, + User: oracleUser, + Env: []string{"ORACLE_HOME=" + oracleHome, "ORACLE_SID=" + oracleSID, "LD_LIBRARY_PATH=" + filepath.Join(oracleHome, "lib")}, + }) + if disableRes.ExitCode != 0 { + return fmt.Errorf("failed to disable database via srvctl: %s", disableRes.StdErr) + } + case startupOratab: + if err := setAutostartInOratab("/etc/oratab", oracleSID, false); err != nil { + return fmt.Errorf("failed to disable autostart in /etc/oratab: %w", err) + } + case startupSystemdFree: + serviceName, err := getOracleFreeSystemdServiceName(ctx) + if err != nil { + return fmt.Errorf("failed to get oracle-free service name: %w", err) + } + res := executeCommand(ctx, commandlineexecutor.Params{ + Executable: "systemctl", + Args: []string{"disable", serviceName}, + }) + if res.ExitCode != 0 { + return fmt.Errorf("failed to disable service %s: %s", serviceName, res.StdErr) + } + default: + return fmt.Errorf("unknown startup mechanism") } + + return nil +} + +// EnableAutostart implements the oracle_enable_autostart guest action. +func EnableAutostart(ctx context.Context, command *gpb.Command, cloudProperties *metadataserver.CloudProperties) *gpb.CommandResult { + params := command.GetAgentCommand().GetParameters() + logger := log.CtxLogger(ctx) + if result := validateParams(ctx, logger, command, params); result != nil { + return result + } + logger = logger.With("oracle_sid", params["oracle_sid"], "oracle_home", params["oracle_home"], "oracle_user", params["oracle_user"]) + logger.Info("oracle_enable_autostart handler called") + + if err := enableAutostart(ctx, params); err != nil { + logger.Warnw("EnableAutostart failed", "error", err) + return commandResult(ctx, logger, command, "", "", codepb.Code_INTERNAL, err.Error(), err) + } + + return commandResult(ctx, logger, command, "Autostart enabled successfully", "", codepb.Code_OK, "Autostart enabled successfully", nil) +} + +func enableAutostart(ctx context.Context, params map[string]string) error { + oracleSID := params["oracle_sid"] + oracleHome := params["oracle_home"] + oracleUser := params["oracle_user"] + dbUniqueName := params["db_unique_name"] + + state, err := detectStartupMechanism(ctx) + if err != nil { + return err + } + + switch state { + case startupOracleRestart: + srvctlPath := filepath.Join(oracleHome, "bin", "srvctl") + res := executeCommand(ctx, commandlineexecutor.Params{ + Executable: srvctlPath, + Args: []string{"enable", "database", "-d", dbUniqueName}, + User: oracleUser, + Env: []string{"ORACLE_HOME=" + oracleHome, "ORACLE_SID=" + oracleSID, "LD_LIBRARY_PATH=" + filepath.Join(oracleHome, "lib")}, + }) + if res.ExitCode != 0 { + return fmt.Errorf("failed to enable database via srvctl: %s", res.StdErr) + } + case startupOratab: + if err := setAutostartInOratab("/etc/oratab", oracleSID, true); err != nil { + return fmt.Errorf("failed to enable autostart in /etc/oratab: %w", err) + } + case startupSystemdFree: + serviceName, err := getOracleFreeSystemdServiceName(ctx) + if err != nil { + return fmt.Errorf("failed to get oracle-free service name: %w", err) + } + res := executeCommand(ctx, commandlineexecutor.Params{ + Executable: "systemctl", + Args: []string{"enable", serviceName}, + }) + if res.ExitCode != 0 { + return fmt.Errorf("failed to enable service %s: %s", serviceName, res.StdErr) + } + default: + return fmt.Errorf("unknown autostart state: %d", state) + } + + return nil } // RunDatapatch implements the oracle_run_datapatch guest action. @@ -68,13 +207,131 @@ func StartListener(ctx context.Context, command *gpb.Command, cloudProperties *m } } -// EnableAutostart implements the oracle_enable_autostart guest action. -func EnableAutostart(ctx context.Context, command *gpb.Command, cloudProperties *metadataserver.CloudProperties) *gpb.CommandResult { - log.CtxLogger(ctx).Info("oracle_enable_autostart handler called") - // TODO: Implement oracle_enable_autostart handler. - return &gpb.CommandResult{ - Command: command, - ExitCode: 1, - Stdout: "oracle_enable_autostart not implemented.", +func detectStartupMechanism(ctx context.Context) (startupMechanism, error) { + // Check for ASM (Oracle Restart) + // ASM implies the presence of Oracle Grid Infrastructure. + // In this configuration, 'Oracle Restart' (part of GI) manages the database lifecycle. + // It ignores the autostart flags in /etc/oratab, relying instead on its own internal registry. + // We detect this by looking for the unique ASM Process Monitor (pmon) process. + pgrepRes := executeCommand(ctx, commandlineexecutor.Params{ + Executable: "pgrep", + Args: []string{"-f", "asm_pmon_+ASM"}, + }) + if pgrepRes.ExitCode == 0 { + return startupOracleRestart, nil + } + + // Check for Filesystem configuration. + // For Filesystem deployments, the oracle-toolkit installs a custom systemd + // service named 'dbora'. + // This service executes a helper script that reads /etc/oratab and starts any instances + // explicitly marked with a 'Y' flag. + sysCtlRes := executeCommand(ctx, commandlineexecutor.Params{ + Executable: "systemctl", + Args: []string{"is-active", "--quiet", "dbora.service"}, + }) + if sysCtlRes.ExitCode == 0 { + return startupOratab, nil + } + + // Check for Oracle Free Edition + // Oracle Database Free Edition packages provide their own native systemd service + // (e.g., 'oracle-free-23c.service') and do not use the toolkit's 'dbora' service. + if _, err := getOracleFreeSystemdServiceName(ctx); err == nil { + return startupSystemdFree, nil + } + + return startupUnknown, fmt.Errorf("unable to detect startup mechanism") +} + +func getOracleFreeSystemdServiceName(ctx context.Context) (string, error) { + listUnitsRes := executeCommand(ctx, commandlineexecutor.Params{ + Executable: "systemctl", + Args: []string{"list-units", "--all", "--plain", "--no-legend", "oracle-free*.service"}, + }) + if listUnitsRes.ExitCode != 0 { + return "", fmt.Errorf("failed to list oracle-free services: %s", listUnitsRes.StdErr) + } + output := strings.TrimSpace(listUnitsRes.StdOut) + if len(output) == 0 { + return "", fmt.Errorf("no oracle-free service found") } + // Take the first one found. + fields := strings.Fields(output) + if len(fields) > 0 { + return fields[0], nil + } + return "", fmt.Errorf("failed to parse systemctl output") +} + +// setAutostartInOratab updates the oratab file to set the autostart flag for the given SID. +func setAutostartInOratab(filePath string, targetSID string, enable bool) error { + content, err := osReadFile(filePath) + if err != nil { + return err + } + + lines := strings.Split(string(content), "\n") + var outputLines []string + + newValue := "N" + if enable { + newValue = "Y" + } + + for _, line := range lines { + trimmed := strings.TrimSpace(line) + if trimmed == "" || strings.HasPrefix(trimmed, "#") { + outputLines = append(outputLines, line) + continue + } + + // Format is $ORACLE_SID:$ORACLE_HOME: + parts := strings.Split(line, ":") + if len(parts) >= 3 && parts[0] == targetSID { + parts[2] = newValue + outputLines = append(outputLines, strings.Join(parts, ":")) + } else { + outputLines = append(outputLines, line) + } + } + + output := strings.Join(outputLines, "\n") + info, err := osStat(filePath) + if err != nil { + return err + } + + return osWriteFile(filePath, []byte(output), info.Mode()) +} + +// isAutostartEnabledInOratab parses the oratab file to see if the given SID is set to 'Y' +func isAutostartEnabledInOratab(filePath string, targetSID string) (bool, error) { + content, err := osReadFile(filePath) + if err != nil { + return false, err + } + + lines := strings.Split(string(content), "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + + if line == "" || strings.HasPrefix(line, "#") { + continue + } + + // Format is $ORACLE_SID:$ORACLE_HOME: + parts := strings.Split(line, ":") + + if len(parts) >= 3 { + currentSID := parts[0] + autoStartFlag := parts[2] + + if currentSID == targetSID { + return autoStartFlag == "Y", nil + } + } + } + + return false, nil } diff --git a/internal/oraclehandlers/patching_test.go b/internal/oraclehandlers/patching_test.go index 7e1e031..d5b2f2d 100644 --- a/internal/oraclehandlers/patching_test.go +++ b/internal/oraclehandlers/patching_test.go @@ -18,26 +18,636 @@ package oraclehandlers import ( "context" + "errors" + "os" "strings" "testing" + "github.com/GoogleCloudPlatform/workloadagentplatform/sharedlibraries/commandlineexecutor" gpb "github.com/GoogleCloudPlatform/workloadagentplatform/sharedprotos/guestactions" ) -func TestDisableAutostart_NotImplemented(t *testing.T) { - command := &gpb.Command{ - CommandType: &gpb.Command_AgentCommand{ - AgentCommand: &gpb.AgentCommand{ - Command: "oracle_disable_autostart", +func TestSetAutostartInOratab(t *testing.T) { + tests := []struct { + name string + initialFile string + targetSID string + enable bool + wantFile string + mockReadErr error + mockStatErr error + mockWriteErr error + wantErr bool + }{ + { + name: "Enable", + initialFile: "ORCL:/u01:N\n", + targetSID: "ORCL", + enable: true, + wantFile: "ORCL:/u01:Y\n", + wantErr: false, + }, + { + name: "Disable", + initialFile: "ORCL:/u01:Y\n", + targetSID: "ORCL", + enable: false, + wantFile: "ORCL:/u01:N\n", + wantErr: false, + }, + { + name: "NoChange", + initialFile: "ORCL:/u01:Y\n", + targetSID: "ORCL", + enable: true, + wantFile: "ORCL:/u01:Y\n", + wantErr: false, + }, + { + name: "SIDNotFound", + initialFile: "OTHER:/u01:N\n", + targetSID: "ORCL", + enable: true, + wantFile: "OTHER:/u01:N\n", + wantErr: false, + }, + { + name: "CommentsPreserved", + initialFile: "# Header\n" + + "ORCL:/u01:N\n", + targetSID: "ORCL", + enable: true, + wantFile: "# Header\n" + + "ORCL:/u01:Y\n", + wantErr: false, + }, + { + name: "ReadError", + initialFile: "", + mockReadErr: errors.New("read error"), + wantErr: true, + }, + { + name: "StatError", + initialFile: "ORCL:/u01:N\n", + mockStatErr: errors.New("stat error"), + wantErr: true, + }, + { + name: "WriteError", + initialFile: "ORCL:/u01:N\n", + mockWriteErr: errors.New("write error"), + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Mock osReadFile + oldOsReadFile := osReadFile + defer func() { osReadFile = oldOsReadFile }() + osReadFile = func(name string) ([]byte, error) { + if tt.mockReadErr != nil { + return nil, tt.mockReadErr + } + return []byte(tt.initialFile), nil + } + + // Mock osStat + oldOsStat := osStat + defer func() { osStat = oldOsStat }() + osStat = func(name string) (os.FileInfo, error) { + if tt.mockStatErr != nil { + return nil, tt.mockStatErr + } + // Create a temporary file to get a valid FileInfo + tmpFile, err := os.CreateTemp("", "mock_oratab") + if err != nil { + t.Fatalf("Failed to create temp file for mock stat: %v", err) + } + defer os.Remove(tmpFile.Name()) + return tmpFile.Stat() + } + + // Mock osWriteFile + oldOsWriteFile := osWriteFile + defer func() { osWriteFile = oldOsWriteFile }() + var capturedWrite []byte + osWriteFile = func(name string, data []byte, perm os.FileMode) error { + if tt.mockWriteErr != nil { + return tt.mockWriteErr + } + capturedWrite = data + return nil + } + + err := setAutostartInOratab("/etc/oratab", tt.targetSID, tt.enable) + if (err != nil) != tt.wantErr { + t.Errorf("setAutostartInOratab() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr && string(capturedWrite) != tt.wantFile { + t.Errorf("setAutostartInOratab() wrote %q, want %q", string(capturedWrite), tt.wantFile) + } + }) + } +} + +func TestIsAutostartEnabledInOratab(t *testing.T) { + tests := []struct { + name string + fileContent string + targetSID string + readFileErr error + want bool + wantErr bool + }{ + { + name: "Enabled", + fileContent: "ORCL:/u01/app/oracle/product/19.0.0/dbhome_1:Y\n", + targetSID: "ORCL", + want: true, + wantErr: false, + }, + { + name: "Disabled", + fileContent: "ORCL:/u01/app/oracle/product/19.0.0/dbhome_1:N\n", + targetSID: "ORCL", + want: false, + wantErr: false, + }, + { + name: "NotFound", + fileContent: "OTHER:/u01/app/oracle/product/19.0.0/dbhome_1:Y\n", + targetSID: "ORCL", + want: false, + wantErr: false, + }, + { + name: "CommentsIgnored", + fileContent: "# This is a comment\n" + + "ORCL:/u01/app/oracle/product/19.0.0/dbhome_1:Y\n", + targetSID: "ORCL", + want: true, + wantErr: false, + }, + { + name: "ReadError", + fileContent: "", + targetSID: "ORCL", + readFileErr: errors.New("read error"), + want: false, + wantErr: true, + }, + { + name: "MalformedLine", + fileContent: "ORCL:/u01/app/oracle/product/19.0.0/dbhome_1\n", // Missing flag + targetSID: "ORCL", + want: false, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + oldOsReadFile := osReadFile + defer func() { osReadFile = oldOsReadFile }() + osReadFile = func(name string) ([]byte, error) { + if tt.readFileErr != nil { + return nil, tt.readFileErr + } + return []byte(tt.fileContent), nil + } + + got, err := isAutostartEnabledInOratab("/etc/oratab", tt.targetSID) + if (err != nil) != tt.wantErr { + t.Errorf("isAutostartEnabledInOratab() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("isAutostartEnabledInOratab() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestGetOracleFreeSystemdServiceName(t *testing.T) { + tests := []struct { + name string + mockOutput string + mockExit int + mockError string + want string + wantErr bool + }{ + { + name: "ServiceFound", + mockOutput: "oracle-free-23c.service loaded active running Oracle Database Free 23c", + mockExit: 0, + want: "oracle-free-23c.service", + wantErr: false, + }, + { + name: "NoServiceFound", + mockOutput: "", + mockExit: 0, + want: "", + wantErr: true, + }, + { + name: "CommandFailed", + mockOutput: "", + mockExit: 1, + mockError: "command failed", + want: "", + wantErr: true, + }, + { + name: "MultipleServicesFirstTaken", + mockOutput: "oracle-free-23c.service loaded active running\noracle-free-21c.service loaded active running", + mockExit: 0, + want: "oracle-free-23c.service", + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + oldExecuteCommand := executeCommand + defer func() { executeCommand = oldExecuteCommand }() + executeCommand = func(ctx context.Context, params commandlineexecutor.Params) commandlineexecutor.Result { + return commandlineexecutor.Result{ + StdOut: tt.mockOutput, + StdErr: tt.mockError, + ExitCode: tt.mockExit, + } + } + + got, err := getOracleFreeSystemdServiceName(context.Background()) + if (err != nil) != tt.wantErr { + t.Errorf("getOracleFreeSystemdServiceName() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("getOracleFreeSystemdServiceName() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestEnableAutostart(t *testing.T) { + defaultParams := map[string]string{ + "oracle_sid": "ORCL", + "oracle_home": "/u01/app/oracle/product/19.0.0/dbhome_1", + "oracle_user": "oracle", + "db_unique_name": "ORCL_SITE1", + } + + tests := []struct { + name string + mockCmds map[string]commandlineexecutor.Result + initialFile string + wantWriteContent string + wantErr bool + }{ + { + name: "OracleRestart_Success", + mockCmds: map[string]commandlineexecutor.Result{ + "pgrep": {ExitCode: 0}, + "srvctl": {ExitCode: 0}, }, + wantErr: false, + }, + { + name: "OracleRestart_Failure", + mockCmds: map[string]commandlineexecutor.Result{ + "pgrep": {ExitCode: 0}, + "srvctl": {ExitCode: 1, StdErr: "srvctl failed"}, + }, + wantErr: true, + }, + { + name: "Oratab_Success", + mockCmds: map[string]commandlineexecutor.Result{ + "pgrep": {ExitCode: 1}, + "dbora": {ExitCode: 0}, + }, + initialFile: "ORCL:/u01:N\n", + wantWriteContent: "ORCL:/u01:Y\n", + wantErr: false, + }, + { + name: "SystemdFree_Success", + mockCmds: map[string]commandlineexecutor.Result{ + "pgrep": {ExitCode: 1}, + "dbora": {ExitCode: 1}, + "oracle-free-list": {ExitCode: 0, StdOut: "oracle-free-23c.service"}, + "enable-service": {ExitCode: 0}, + }, + wantErr: false, + }, + { + name: "SystemdFree_Failure", + mockCmds: map[string]commandlineexecutor.Result{ + "pgrep": {ExitCode: 1}, + "dbora": {ExitCode: 1}, + "oracle-free-list": {ExitCode: 0, StdOut: "oracle-free-23c.service"}, + "enable-service": {ExitCode: 1, StdErr: "systemctl failed"}, + }, + wantErr: true, + }, + { + name: "UnknownStartupMechanism", + mockCmds: map[string]commandlineexecutor.Result{ + "pgrep": {ExitCode: 1}, + "dbora": {ExitCode: 1}, + "oracle-free-list": {ExitCode: 0, StdOut: ""}, + }, + wantErr: true, }, } - result := DisableAutostart(context.Background(), command, nil) - if result.GetExitCode() != 1 { - t.Errorf("DisableAutostart() returned exit code %d, want 1", result.GetExitCode()) + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Mock executeCommand + oldExecuteCommand := executeCommand + defer func() { executeCommand = oldExecuteCommand }() + executeCommand = func(ctx context.Context, params commandlineexecutor.Params) commandlineexecutor.Result { + cmd := params.Executable + args := strings.Join(params.Args, " ") + + if cmd == "pgrep" && strings.Contains(args, "asm_pmon_+ASM") { + return tt.mockCmds["pgrep"] + } + if strings.HasSuffix(cmd, "srvctl") { + return tt.mockCmds["srvctl"] + } + if cmd == "systemctl" { + if strings.Contains(args, "is-active") && strings.Contains(args, "dbora.service") { + return tt.mockCmds["dbora"] + } + if strings.Contains(args, "list-units") && strings.Contains(args, "oracle-free") { + return tt.mockCmds["oracle-free-list"] + } + if strings.Contains(args, "enable") { + return tt.mockCmds["enable-service"] + } + } + return commandlineexecutor.Result{ExitCode: 1, StdErr: "mock command not found: " + cmd + " " + args} + } + + // Mock file system for oratab + oldOsReadFile := osReadFile + defer func() { osReadFile = oldOsReadFile }() + osReadFile = func(name string) ([]byte, error) { + return []byte(tt.initialFile), nil + } + + oldOsStat := osStat + defer func() { osStat = oldOsStat }() + osStat = func(name string) (os.FileInfo, error) { + tmpFile, _ := os.CreateTemp("", "mock_oratab") + defer os.Remove(tmpFile.Name()) + return tmpFile.Stat() + } + + oldOsWriteFile := osWriteFile + defer func() { osWriteFile = oldOsWriteFile }() + var capturedWrite []byte + osWriteFile = func(name string, data []byte, perm os.FileMode) error { + capturedWrite = data + return nil + } + + err := enableAutostart(context.Background(), defaultParams) + if (err != nil) != tt.wantErr { + t.Errorf("enableAutostart() error = %v, wantErr %v", err, tt.wantErr) + } + + if tt.wantWriteContent != "" { + if string(capturedWrite) != tt.wantWriteContent { + t.Errorf("enableAutostart() wrote %q, want %q", string(capturedWrite), tt.wantWriteContent) + } + } + }) } - if !strings.Contains(result.GetStdout(), "not implemented") { - t.Errorf("DisableAutostart() returned stdout %q, want 'not implemented'", result.GetStdout()) +} + +func TestDisableAutostart(t *testing.T) { + defaultParams := map[string]string{ + "oracle_sid": "ORCL", + "oracle_home": "/u01/app/oracle/product/19.0.0/dbhome_1", + "oracle_user": "oracle", + "db_unique_name": "ORCL_SITE1", + } + + tests := []struct { + name string + mockCmds map[string]commandlineexecutor.Result + initialFile string + wantWriteContent string + wantErr bool + }{ + { + name: "OracleRestart_Success", + mockCmds: map[string]commandlineexecutor.Result{ + "pgrep": {ExitCode: 0}, // Found ASM + "srvctl": {ExitCode: 0}, // Disable success + }, + wantErr: false, + }, + { + name: "OracleRestart_Failure", + mockCmds: map[string]commandlineexecutor.Result{ + "pgrep": {ExitCode: 0}, + "srvctl": {ExitCode: 1, StdErr: "srvctl failed"}, + }, + wantErr: true, + }, + { + name: "Oratab_Success", + mockCmds: map[string]commandlineexecutor.Result{ + "pgrep": {ExitCode: 1}, // No ASM + "dbora": {ExitCode: 0}, // dbora active + }, + initialFile: "ORCL:/u01:Y\n", + wantWriteContent: "ORCL:/u01:N\n", + wantErr: false, + }, + { + name: "SystemdFree_Success", + mockCmds: map[string]commandlineexecutor.Result{ + "pgrep": {ExitCode: 1}, + "dbora": {ExitCode: 1}, + "oracle-free-list": {ExitCode: 0, StdOut: "oracle-free-23c.service"}, + "disable-service": {ExitCode: 0}, + }, + wantErr: false, + }, + { + name: "SystemdFree_Failure", + mockCmds: map[string]commandlineexecutor.Result{ + "pgrep": {ExitCode: 1}, + "dbora": {ExitCode: 1}, + "oracle-free-list": {ExitCode: 0, StdOut: "oracle-free-23c.service"}, + "disable-service": {ExitCode: 1, StdErr: "systemctl failed"}, + }, + wantErr: true, + }, + { + name: "UnknownStartupMechanism", + mockCmds: map[string]commandlineexecutor.Result{ + "pgrep": {ExitCode: 1}, + "dbora": {ExitCode: 1}, + "oracle-free-list": {ExitCode: 0, StdOut: ""}, + }, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Mock executeCommand + oldExecuteCommand := executeCommand + defer func() { executeCommand = oldExecuteCommand }() + executeCommand = func(ctx context.Context, params commandlineexecutor.Params) commandlineexecutor.Result { + cmd := params.Executable + args := strings.Join(params.Args, " ") + + if cmd == "pgrep" && strings.Contains(args, "asm_pmon_+ASM") { + return tt.mockCmds["pgrep"] + } + if strings.HasSuffix(cmd, "srvctl") { + return tt.mockCmds["srvctl"] + } + if cmd == "systemctl" { + if strings.Contains(args, "is-active") && strings.Contains(args, "dbora.service") { + return tt.mockCmds["dbora"] + } + if strings.Contains(args, "list-units") && strings.Contains(args, "oracle-free") { + return tt.mockCmds["oracle-free-list"] + } + if strings.Contains(args, "disable") { + return tt.mockCmds["disable-service"] + } + } + return commandlineexecutor.Result{ExitCode: 1, StdErr: "mock command not found: " + cmd + " " + args} + } + + // Mock file system for oratab + oldOsReadFile := osReadFile + defer func() { osReadFile = oldOsReadFile }() + osReadFile = func(name string) ([]byte, error) { + return []byte(tt.initialFile), nil + } + + oldOsStat := osStat + defer func() { osStat = oldOsStat }() + osStat = func(name string) (os.FileInfo, error) { + tmpFile, _ := os.CreateTemp("", "mock_oratab") + defer os.Remove(tmpFile.Name()) + return tmpFile.Stat() + } + + oldOsWriteFile := osWriteFile + defer func() { osWriteFile = oldOsWriteFile }() + var capturedWrite []byte + osWriteFile = func(name string, data []byte, perm os.FileMode) error { + capturedWrite = data + return nil + } + + err := disableAutostart(context.Background(), defaultParams) + if (err != nil) != tt.wantErr { + t.Errorf("disableAutostart() error = %v, wantErr %v", err, tt.wantErr) + } + + if tt.wantWriteContent != "" { + if string(capturedWrite) != tt.wantWriteContent { + t.Errorf("disableAutostart() wrote %q, want %q", string(capturedWrite), tt.wantWriteContent) + } + } + }) + } +} + +func TestDetectStartupMechanism(t *testing.T) { + tests := []struct { + name string + mockCmds map[string]commandlineexecutor.Result + want startupMechanism + wantErr bool + }{ + { + name: "OracleRestart_ASM", + mockCmds: map[string]commandlineexecutor.Result{ + "pgrep": {ExitCode: 0}, + }, + want: startupOracleRestart, + wantErr: false, + }, + { + name: "Oratab_Dbora", + mockCmds: map[string]commandlineexecutor.Result{ + "pgrep": {ExitCode: 1}, + "dbora": {ExitCode: 0}, + }, + want: startupOratab, + wantErr: false, + }, + { + name: "SystemdFree", + mockCmds: map[string]commandlineexecutor.Result{ + "pgrep": {ExitCode: 1}, + "dbora": {ExitCode: 1}, + "oracle-free": {ExitCode: 0, StdOut: "oracle-free-23c.service"}, + }, + want: startupSystemdFree, + wantErr: false, + }, + { + name: "Unknown", + mockCmds: map[string]commandlineexecutor.Result{ + "pgrep": {ExitCode: 1}, + "dbora": {ExitCode: 1}, + "oracle-free": {ExitCode: 0, StdOut: ""}, + }, + want: startupUnknown, + wantErr: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + oldExecuteCommand := executeCommand + defer func() { executeCommand = oldExecuteCommand }() + executeCommand = func(ctx context.Context, params commandlineexecutor.Params) commandlineexecutor.Result { + cmd := params.Executable + args := strings.Join(params.Args, " ") + + if cmd == "pgrep" && strings.Contains(args, "asm_pmon_+ASM") { + return tt.mockCmds["pgrep"] + } + if cmd == "systemctl" { + if strings.Contains(args, "is-active") && strings.Contains(args, "dbora.service") { + return tt.mockCmds["dbora"] + } + if strings.Contains(args, "list-units") && strings.Contains(args, "oracle-free") { + return tt.mockCmds["oracle-free"] + } + } + return commandlineexecutor.Result{ExitCode: 1, StdErr: "mock command not found"} + } + + got, err := detectStartupMechanism(context.Background()) + if (err != nil) != tt.wantErr { + t.Errorf("detectStartupMechanism() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("detectStartupMechanism() = %v, want %v", got, tt.want) + } + }) } } @@ -91,20 +701,3 @@ func TestStartListener_NotImplemented(t *testing.T) { t.Errorf("StartListener() returned stdout %q, want 'not implemented'", result.GetStdout()) } } - -func TestEnableAutostart_NotImplemented(t *testing.T) { - command := &gpb.Command{ - CommandType: &gpb.Command_AgentCommand{ - AgentCommand: &gpb.AgentCommand{ - Command: "oracle_enable_autostart", - }, - }, - } - result := EnableAutostart(context.Background(), command, nil) - if result.GetExitCode() != 1 { - t.Errorf("EnableAutostart() returned exit code %d, want 1", result.GetExitCode()) - } - if !strings.Contains(result.GetStdout(), "not implemented") { - t.Errorf("EnableAutostart() returned stdout %q, want 'not implemented'", result.GetStdout()) - } -}