diff --git a/go.mod b/go.mod
index d4043c5..2c47a49 100644
--- a/go.mod
+++ b/go.mod
@@ -21,6 +21,7 @@ require (
 	github.com/jarcoal/httpmock v1.4.0
 	github.com/nebius/gosdk v0.0.0-20250826102719-940ad1dfb5de
 	github.com/pkg/errors v0.9.1
+	github.com/sfcompute/nodes-go v0.1.0-alpha.4
 	github.com/stretchr/testify v1.11.1
 	golang.org/x/crypto v0.42.0
 	golang.org/x/text v0.29.0
@@ -83,6 +84,10 @@ require (
 	github.com/sirupsen/logrus v1.9.3 // indirect
 	github.com/spf13/afero v1.15.0 // indirect
 	github.com/spf13/pflag v1.0.10 // indirect
+	github.com/tidwall/gjson v1.18.0 // indirect
+	github.com/tidwall/match v1.1.1 // indirect
+	github.com/tidwall/pretty v1.2.1 // indirect
+	github.com/tidwall/sjson v1.2.5 // indirect
 	github.com/x448/float16 v0.8.4 // indirect
 	go.yaml.in/yaml/v2 v2.4.3 // indirect
 	go.yaml.in/yaml/v3 v3.0.4 // indirect
diff --git a/go.sum b/go.sum
index 1f70d3a..7727c30 100644
--- a/go.sum
+++ b/go.sum
@@ -160,6 +160,9 @@ github.com/prometheus/procfs v0.17.0 h1:FuLQ+05u4ZI+SS/w9+BWEM2TXiHKsUQ9TADiRH7D
 github.com/prometheus/procfs v0.17.0/go.mod h1:oPQLaDAMRbA+u8H5Pbfq+dl3VDAvHxMUOVhe0wYB2zw=
 github.com/rogpeppe/go-internal v1.13.1 h1:KvO1DLK/DRN07sQ1LQKScxyZJuNnedQ5/wKSR38lUII=
 github.com/rogpeppe/go-internal v1.13.1/go.mod h1:uMEvuHeurkdAXX61udpOXGD/AzZDWNMNyH2VO9fmH0o=
+github.com/sfcompute/nodes-go v0.1.0-alpha.3/go.mod h1:dF3O8MCxLz3FTVYhjCa876Z9O3EAM8E8fONivDpfmkM=
+github.com/sfcompute/nodes-go v0.1.0-alpha.4 h1:oFBWcMPSpqLYm/NDs5I1jTvzgx9rsXDL9Ghsm30Hc0Q=
+github.com/sfcompute/nodes-go v0.1.0-alpha.4/go.mod h1:nUviHgK+Fgt2hDFcRL3M8VoyiypC8fc0dsY8C30QU8M=
 github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
 github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
 github.com/spf13/afero v1.15.0 h1:b/YBCLWAJdFWJTN9cLhiXXcD7mzKn9Dm86dNnfyQw1I=
@@ -180,6 +183,16 @@ github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXl
 github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
 github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U=
 github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U=
+github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
+github.com/tidwall/gjson v1.18.0 h1:FIDeeyB800efLX89e5a8Y0BNH+LOngJyGrIWxG2FKQY=
+github.com/tidwall/gjson v1.18.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
+github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA=
+github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM=
+github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
+github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
+github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
+github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
+github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
 github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM=
 github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg=
 github.com/yuin/goldmark v1.1.27/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9decYSb74=
diff --git a/v1/providers/sfcompute/capabilities.go b/v1/providers/sfcompute/capabilities.go
new file mode 100644
index 0000000..ac0604a
--- /dev/null
+++ b/v1/providers/sfcompute/capabilities.go
@@ -0,0 +1,28 @@
+package v1
+
+import (
+	"context"
+
+	v1 "github.com/brevdev/cloud/v1"
+)
+
+// getSFCCapabilities returns the capability set shared by SFCClient and
+// SFCCredential: instance creation and termination only.
+func getSFCCapabilities() v1.Capabilities {
+	return v1.Capabilities{
+		v1.CapabilityCreateInstance,
+		v1.CapabilityTerminateInstance,
+		v1.CapabilityCreateTerminateInstance,
+		// Extend as support lands: reboot, stop/start, machine-image, tags, resize-volume, modify-firewall, etc.
+	}
+}
+
+// GetCapabilities reports the capabilities supported through an SFC client.
+func (c *SFCClient) GetCapabilities(_ context.Context) (v1.Capabilities, error) {
+	return getSFCCapabilities(), nil
+}
+
+// GetCapabilities reports the capabilities supported through an SFC credential.
+func (c *SFCCredential) GetCapabilities(_ context.Context) (v1.Capabilities, error) {
+	return getSFCCapabilities(), nil
+}
diff --git a/v1/providers/sfcompute/client.go b/v1/providers/sfcompute/client.go
new file mode 100644
index 0000000..fce08ca
--- /dev/null
+++ b/v1/providers/sfcompute/client.go
@@ -0,0 +1,95 @@
+package v1
+
+import (
+	"context"
+
+	v1 "github.com/brevdev/cloud/v1"
+	"github.com/sfcompute/nodes-go/option"
+
+	sfcnodes "github.com/sfcompute/nodes-go"
+)
+
+// sfcProviderID is the provider identifier reported by both the credential and
+// the client.
+const sfcProviderID v1.CloudProviderID = "sfcompute"
+
+// SFCCredential holds what is needed to build an SFCClient. apiKey is
+// unexported and therefore invisible to encoding/json, so it carries no json
+// tag (go vet's structtag check rejects tags on unexported fields).
+type SFCCredential struct {
+	RefID  string
+	apiKey string
+}
+
+var _ v1.CloudCredential = &SFCCredential{}
+
+// NewSFCCredential returns a credential for the given reference ID and API key.
+func NewSFCCredential(refID string, apiKey string) *SFCCredential {
+	return &SFCCredential{
+		RefID:  refID,
+		apiKey: apiKey,
+	}
+}
+
+// GetReferenceID returns the reference ID the credential was created with.
+func (c *SFCCredential) GetReferenceID() string { return c.RefID }
+
+// GetAPIType reports the credential as locational.
+// NOTE(review): SFCClient.GetAPIType reports v1.APITypeGlobal; confirm which of
+// the two is intended and align them.
+func (c *SFCCredential) GetAPIType() v1.APIType { return v1.APITypeLocational }
+
+// GetCloudProviderID returns the SF Compute provider identifier.
+func (c *SFCCredential) GetCloudProviderID() v1.CloudProviderID { return sfcProviderID }
+
+// GetTenantID returns an empty string because SFC has no tenant system.
+func (c *SFCCredential) GetTenantID() (string, error) { return "", nil }
+
+// MakeClient builds a new SFCClient from this credential and binds it to
+// location.
+func (c *SFCCredential) MakeClient(ctx context.Context, location string) (v1.CloudClient, error) {
+	return NewSFCClient(c.RefID, c.apiKey).MakeClient(ctx, location)
+}
+
+// ---------------- Client ----------------
+
+// SFCClient implements v1.CloudClient on top of the SFC nodes SDK. Methods it
+// does not define itself are promoted from the embedded v1.NotImplCloudClient.
+type SFCClient struct {
+	v1.NotImplCloudClient
+	refID    string
+	location string
+	apiKey   string
+	client   sfcnodes.Client
+}
+
+var _ v1.CloudClient = &SFCClient{}
+
+// NewSFCClient returns a client whose SDK calls authenticate with apiKey as a
+// bearer token.
+func NewSFCClient(refID string, apiKey string) *SFCClient {
+	return &SFCClient{
+		refID:  refID,
+		apiKey: apiKey,
+		client: sfcnodes.NewClient(option.WithBearerToken(apiKey)),
+	}
+}
+
+// GetAPIType reports the client as global.
+func (c *SFCClient) GetAPIType() v1.APIType { return v1.APITypeGlobal }
+
+// GetCloudProviderID returns the SF Compute provider identifier.
+func (c *SFCClient) GetCloudProviderID() v1.CloudProviderID { return sfcProviderID }
+
+// GetReferenceID returns the reference ID the client was created with.
+func (c *SFCClient) GetReferenceID() string { return c.refID }
+
+// GetTenantID returns an empty string because SFC has no tenant system.
+func (c *SFCClient) GetTenantID() (string, error) { return "", nil }
+
+// MakeClient records location on the receiver and returns the receiver itself;
+// no new client is allocated.
+func (c *SFCClient) MakeClient(_ context.Context, location string) (v1.CloudClient, error) {
+	c.location = location
+	return c, nil
+}
diff --git a/v1/providers/sfcompute/instance.go b/v1/providers/sfcompute/instance.go
new file mode 100644
index 0000000..4658eee
--- /dev/null
+++ b/v1/providers/sfcompute/instance.go
@@ -0,0 +1,179 @@
+package v1
+
+import (
+	"context"
+	"encoding/base64"
+	"fmt"
+	"strings"
+	"time"
+
+	v1 "github.com/brevdev/cloud/v1"
+	sfcnodes "github.com/sfcompute/nodes-go"
+	"github.com/sfcompute/nodes-go/packages/param"
+)
+
+const (
+	// sfcSSHPort is the TCP port used for SSH on every SFC node.
+	sfcSSHPort = 2222
+
+	// sfcMaxPricePerNodeHour is the price ceiling sent with each node request.
+	sfcMaxPricePerNodeHour = 1600
+)
+
+// toBase64 returns the standard (padded) base64 encoding of s.
+func toBase64(s string) string {
+	return base64.StdEncoding.EncodeToString([]byte(s))
+}
+
+// sshKeyCloudInit returns a base64-encoded cloud-config document that installs
+// sshKey as an authorized SSH key.
+func sshKeyCloudInit(sshKey string) string {
+	return toBase64(fmt.Sprintf("#cloud-config\nssh_authorized_keys:\n - %s", sshKey))
+}
+
+// mapSFCStatus converts an SFC node status (matched case-insensitively) into a
+// v1.LifecycleStatus. Every status without an explicit case is reported as
+// pending; that covers "pending", "awaitingcapacity", "unspecified", "unknown",
+// "failed" and "nodefailure".
+// NOTE(review): "failed" and "nodefailure" read like terminal states yet are
+// surfaced as pending; confirm that is intended.
+func mapSFCStatus(s string) v1.LifecycleStatus {
+	switch strings.ToLower(s) {
+	case "running":
+		return v1.LifecycleStatusRunning
+	case "stopped":
+		return v1.LifecycleStatusStopped
+	case "terminating", "released":
+		return v1.LifecycleStatusTerminating
+	case "destroyed", "deleted":
+		return v1.LifecycleStatusTerminated
+	default:
+		return v1.LifecycleStatusPending
+	}
+}
+
+// CreateInstance requests one node in attrs.Location, booted from attrs.ImageID
+// (which must name a valid image) with attrs.PublicKey installed through
+// cloud-init. The returned instance describes the first node in the response.
+func (c *SFCClient) CreateInstance(ctx context.Context, attrs v1.CreateInstanceAttrs) (*v1.Instance, error) {
+	resp, err := c.client.Nodes.New(ctx, sfcnodes.NodeNewParams{
+		CreateNodesRequest: sfcnodes.CreateNodesRequestParam{
+			DesiredCount:        1,
+			MaxPricePerNodeHour: sfcMaxPricePerNodeHour,
+			Zone:                attrs.Location,
+			ImageID:             param.Opt[string]{Value: attrs.ImageID},
+			CloudInitUserData:   param.Opt[string]{Value: sshKeyCloudInit(attrs.PublicKey)},
+		},
+	})
+	if err != nil {
+		return nil, fmt.Errorf("creating node in zone %q: %w", attrs.Location, err)
+	}
+	if len(resp.Data) == 0 {
+		return nil, fmt.Errorf("no nodes returned")
+	}
+	node := resp.Data[0]
+
+	return &v1.Instance{
+		Name:           attrs.Name,
+		RefID:          attrs.RefID,
+		CloudCredRefID: c.refID,
+		CloudID:        v1.CloudProviderInstanceID(node.ID),
+		ImageID:        attrs.ImageID,
+		InstanceType:   attrs.InstanceType,
+		Location:       attrs.Location,
+		CreatedAt:      time.Now(),
+		Status:         v1.Status{LifecycleStatus: mapSFCStatus(fmt.Sprint(node.Status))},
+		InstanceTypeID: v1.InstanceTypeID(node.GPUType),
+		SSHPort:        sfcSSHPort,
+	}, nil
+}
+
+// GetInstance fetches the node identified by id and converts it into a
+// v1.Instance. PublicIP is taken from the SSH endpoint of the node's first VM
+// and is left empty when the node has no VM. API failures are returned to the
+// caller, wrapped with the operation that failed.
+func (c *SFCClient) GetInstance(ctx context.Context, id v1.CloudProviderInstanceID) (*v1.Instance, error) {
+	node, err := c.client.Nodes.Get(ctx, string(id))
+	if err != nil {
+		return nil, fmt.Errorf("getting node %q: %w", id, err)
+	}
+
+	inst := &v1.Instance{
+		Name:           node.Name,
+		RefID:          c.refID,
+		CloudCredRefID: c.refID,
+		CloudID:        v1.CloudProviderInstanceID(node.ID),
+		CreatedAt:      time.Unix(node.CreatedAt, 0),
+		Status:         v1.Status{LifecycleStatus: mapSFCStatus(fmt.Sprint(node.Status))},
+		InstanceTypeID: v1.InstanceTypeID(node.GPUType),
+		SSHPort:        sfcSSHPort,
+	}
+
+	// Without a VM there is no VM ID to look an SSH endpoint up for.
+	if len(node.VMs.Data) == 0 {
+		return inst, nil
+	}
+	vmID := node.VMs.Data[0].ID
+	ssh, err := c.client.VMs.SSH(ctx, sfcnodes.VMSSHParams{VMID: vmID})
+	if err != nil {
+		return nil, fmt.Errorf("looking up SSH endpoint of VM %q: %w", vmID, err)
+	}
+	inst.PublicIP = ssh.SSHHostname
+	return inst, nil
+}
+
+// ListInstances lists every node visible to the client and resolves each one
+// through GetInstance, so it costs one extra lookup per node. args is not
+// consulted.
+func (c *SFCClient) ListInstances(ctx context.Context, args v1.ListInstancesArgs) ([]v1.Instance, error) {
+	resp, err := c.client.Nodes.List(ctx, sfcnodes.NodeListParams{})
+	if err != nil {
+		return nil, fmt.Errorf("listing nodes: %w", err)
+	}
+
+	var instances []v1.Instance
+	for _, node := range resp.Data {
+		inst, err := c.GetInstance(ctx, v1.CloudProviderInstanceID(node.ID))
+		if err != nil {
+			return nil, err
+		}
+		instances = append(instances, *inst)
+	}
+	return instances, nil
+}
+
+// TerminateInstance releases the node identified by id and then deletes it.
+// A failure of either step is returned to the caller; the delete is not
+// attempted when the release fails.
+func (c *SFCClient) TerminateInstance(ctx context.Context, id v1.CloudProviderInstanceID) error {
+	if _, err := c.client.Nodes.Release(ctx, string(id)); err != nil {
+		return fmt.Errorf("releasing node %q: %w", id, err)
+	}
+	if err := c.client.Nodes.Delete(ctx, string(id)); err != nil {
+		return fmt.Errorf("deleting node %q: %w", id, err)
+	}
+	return nil
+}
+
+// RebootInstance always returns a "not implemented" error.
+func (c *SFCClient) RebootInstance(ctx context.Context, id v1.CloudProviderInstanceID) error {
+	return fmt.Errorf("not implemented")
+}
+
+// StopInstance always returns a "not implemented" error.
+func (c *SFCClient) StopInstance(ctx context.Context, id v1.CloudProviderInstanceID) error {
+	return fmt.Errorf("not implemented")
+}
+
+// StartInstance always returns a "not implemented" error.
+func (c *SFCClient) StartInstance(ctx context.Context, id v1.CloudProviderInstanceID) error {
+	return fmt.Errorf("not implemented")
+}
+
+// Merge strategies (pass-through is acceptable baseline).
+func (c *SFCClient) MergeInstanceForUpdate(_ v1.Instance, newInst v1.Instance) v1.Instance {
+	return newInst // pass-through: the current instance is ignored
+}
+func (c *SFCClient) MergeInstanceTypeForUpdate(_ v1.InstanceType, newIt v1.InstanceType) v1.InstanceType {
+	return newIt // pass-through: the current instance type is ignored
+}
diff --git a/v1/providers/sfcompute/instancetype.go b/v1/providers/sfcompute/instancetype.go
new file mode 100644
index 0000000..733a8b9
--- /dev/null
+++ b/v1/providers/sfcompute/instancetype.go
@@ -0,0 +1,118 @@
+package v1
+
+import (
+	"context"
+	"fmt"
+	"slices"
+	"strconv"
+	"time"
+
+	"github.com/bojanz/currency"
+	"github.com/brevdev/cloud/internal/collections"
+
+	v1 "github.com/brevdev/cloud/v1"
+)
+
+const (
+	// sfcGPUType is the only instance/GPU type this provider advertises.
+	sfcGPUType = "h100v"
+
+	// sfcBasePriceUSD is the base price, in USD, reported for every type.
+	sfcBasePriceUSD = 2
+
+	// sfcDeliveryTypeVM is the delivery type a zone needs to count as available.
+	sfcDeliveryTypeVM = "VM"
+)
+
+// getInstanceTypeID builds a zone's instance type ID: "h100v_" plus the region.
+func (c *SFCClient) getInstanceTypeID(region string) string {
+	return sfcGPUType + "_" + region
+}
+
+// GetInstanceTypes returns one h100v instance type per zone reported by the SFC
+// API, restricted to args.Locations and args.InstanceTypes when those are
+// non-empty. A type is available when its zone has available capacity and is
+// delivered as VMs.
+func (c *SFCClient) GetInstanceTypes(ctx context.Context, args v1.GetInstanceTypeArgs) ([]v1.InstanceType, error) {
+	resp, err := c.client.Zones.List(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("listing zones: %w", err)
+	}
+
+	types := make([]v1.InstanceType, 0, len(resp.Data))
+	for _, zone := range resp.Data {
+		if len(args.Locations) > 0 && !args.Locations.IsAllowed(zone.Name) {
+			continue
+		}
+
+		// Built per zone so every instance type owns its BasePrice value.
+		price, err := currency.NewAmount(strconv.Itoa(sfcBasePriceUSD), "USD")
+		if err != nil {
+			return nil, fmt.Errorf("building base price: %w", err)
+		}
+		types = append(types, v1.InstanceType{
+			ID:                  v1.InstanceTypeID(c.getInstanceTypeID(zone.Name)),
+			IsAvailable:         len(zone.AvailableCapacity) > 0 && zone.DeliveryType == sfcDeliveryTypeVM,
+			Type:                sfcGPUType,
+			Location:            zone.Name,
+			Stoppable:           false,
+			Rebootable:          false,
+			IsContainer:         false,
+			BasePrice:           &price,
+			EstimatedDeployTime: collections.Ptr(15 * time.Minute),
+			SupportedGPUs: []v1.GPU{{
+				Count:        8,
+				Type:         sfcGPUType,
+				Manufacturer: "nvidia",
+				Name:         sfcGPUType,
+				MemoryBytes:  v1.NewBytes(80, v1.Gibibyte),
+			}},
+		})
+	}
+
+	if len(args.InstanceTypes) == 0 {
+		return types, nil
+	}
+	filtered := make([]v1.InstanceType, 0, len(types))
+	for _, t := range types {
+		if slices.Contains(args.InstanceTypes, t.Type) {
+			filtered = append(filtered, t)
+		}
+	}
+	return filtered, nil
+}
+
+// GetLocations returns one location per zone name reported by the SFC API, in
+// the order the API lists them. A location is available only when its zone has
+// available capacity, is delivered as VMs, and is on the allow-list. When a
+// zone name repeats, the last entry replaces the earlier one.
+func (c *SFCClient) GetLocations(ctx context.Context, _ v1.GetLocationsArgs) ([]v1.Location, error) {
+	resp, err := c.client.Zones.List(ctx)
+	if err != nil {
+		return nil, fmt.Errorf("listing zones: %w", err)
+	}
+	if resp == nil {
+		return []v1.Location{}, nil
+	}
+
+	allowedZones := []string{"hayesvalley"}
+	locations := make([]v1.Location, 0, len(resp.Data))
+	indexByName := make(map[string]int, len(resp.Data))
+	for _, zone := range resp.Data {
+		available := len(zone.AvailableCapacity) > 0 &&
+			zone.DeliveryType == sfcDeliveryTypeVM &&
+			slices.Contains(allowedZones, zone.Name)
+		location := v1.Location{
+			Name:        zone.Name,
+			Description: fmt.Sprintf("sfc_%s_%s", zone.Name, string(zone.HardwareType)),
+			Available:   available,
+		}
+		if i, seen := indexByName[zone.Name]; seen {
+			locations[i] = location
+			continue
+		}
+		indexByName[zone.Name] = len(locations)
+		locations = append(locations, location)
+	}
+	return locations, nil
+}
diff --git a/v1/providers/sfcompute/validation_test.go b/v1/providers/sfcompute/validation_test.go
new file mode 100644
index 0000000..1785e2c
--- /dev/null
+++ b/v1/providers/sfcompute/validation_test.go
@@ -0,0 +1,44 @@
+package v1
+
+import (
+	"os"
+	"testing"
+
+	"github.com/brevdev/cloud/internal/validation"
+	v1 "github.com/brevdev/cloud/v1"
+)
+
+// TestValidationFunctions runs the shared provider validation suite with an
+// SFC credential built from SFCOMPUTE_API_KEY.
+func TestValidationFunctions(t *testing.T) {
+	checkSkip(t)
+
+	config := validation.ProviderConfig{
+		Credential: NewSFCCredential("validation-test", getAPIKey()),
+		StableIDs: []v1.InstanceTypeID{
+			"h100v_hayesvalley",
+		},
+	}
+
+	validation.RunValidationSuite(t, config)
+}
+
+// checkSkip decides what to do when SFCOMPUTE_API_KEY is missing: fail the
+// test if VALIDATION_TEST is "true" (validation was explicitly requested),
+// otherwise skip it. With a key present it does nothing.
+func checkSkip(t *testing.T) {
+	t.Helper()
+	if getAPIKey() != "" {
+		return
+	}
+	if os.Getenv("VALIDATION_TEST") == "true" {
+		t.Fatal("SFCOMPUTE_API_KEY not set, but VALIDATION_TEST is set")
+	}
+	t.Skip("SFCOMPUTE_API_KEY not set, skipping sfcompute validation tests")
+}
+
+// getAPIKey returns the SFC API key from the SFCOMPUTE_API_KEY environment
+// variable, or "" when it is unset.
+func getAPIKey() string {
+	return os.Getenv("SFCOMPUTE_API_KEY")
+}