diff --git a/go.mod b/go.mod index 44803cd..abb53bd 100644 --- a/go.mod +++ b/go.mod @@ -15,6 +15,13 @@ require ( sigs.k8s.io/aws-iam-authenticator v0.6.29 ) +replace ( + github.com/mattn/go-sqlite3 v1.14.16 => github.com/mattn/go-sqlite3 v1.14.18 + golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 => golang.org/x/crypto v0.1.0 + golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550 => golang.org/x/crypto v0.1.0 + golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9 => golang.org/x/crypto v0.1.0 +) + require ( github.com/aws/aws-sdk-go-v2 v1.34.0 // indirect github.com/aws/smithy-go v1.22.2 // indirect diff --git a/go.sum b/go.sum index a683a54..d7f6e98 100644 --- a/go.sum +++ b/go.sum @@ -114,6 +114,9 @@ golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20200226121028-0de0cce0169b/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20201021035429-f5854403a974/go.mod h1:sp8m0HH+o8qH0wwXwYZr8TS3Oi6o0r6Gce1SSxlDquU= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco= golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= golang.org/x/oauth2 v0.25.0 h1:CY4y7XT9v0cRI9oupztF8AgiIu99L/ksR/Xp/6jrZ70= @@ -125,12 +128,19 @@ golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5h golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU= golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.1.0/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.28.0 h1:/Ts8HFuMR2E6IP/jlo7QVLZHggjKQbhu/7H0LJFr3Gg= golang.org/x/term v0.28.0/go.mod h1:Sw/lC2IAUZ92udQNf3WodGtn4k/XoLyZoh8v/8uiwek= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/time v0.9.0 h1:EsRrnYcQiGH+5FfbgvV4AP7qEZstoyrHB0DzarOQ4ZY= @@ -139,6 +149,7 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20200619180055-7c47624df98f/go.mod h1:EkVYQZoAsY45+roYkvgYkIh4xh/qjgUK9TdY2XT94GE= golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/tools v0.26.0 h1:v/60pFQmzmT9ExmjDv2gGIfi3OqfKoEP6I5+umXlbnQ= golang.org/x/tools v0.26.0/go.mod h1:TPVVj70c7JJ3WCazhD8OdXcZg/og+b9+tH/KxylGwH0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/hatchery/config.go b/hatchery/config.go index e9871c8..68ef9e3 100644 --- a/hatchery/config.go +++ b/hatchery/config.go @@ -1,6 +1,7 @@ package hatchery import ( + "github.com/aws/aws-sdk-go/service/costexplorer/costexploreriface" "github.com/aws/aws-sdk-go/service/dynamodb/dynamodbiface" k8sv1 "k8s.io/api/core/v1" @@ -110,6 +111,10 @@ type AllPayModels struct { PayModels []PayModel `json:"all_pay_models"` } +type CostExplorerClient struct { + CostExporer costexploreriface.CostExplorerAPI +} + type DbConfig struct { DynamoDb dynamodbiface.DynamoDBAPI } @@ -133,7 +138,11 @@ type HatcheryConfig struct { Sidecar SidecarContainer `json:"sidecar"` MoreConfigs []AppConfigInfo `json:"more-configs"` PrismaConfig PrismaConfig `json:"prisma"` + Karpenter bool `json:"karpenter"` + DefaultHardLimit float32 `json:"default-hard-limit"` + DefaultSoftLimit float32 `json:"default-soft-limit"` NextflowGlobalConfig NextflowGlobalConfig `json:"nextflow-global"` + Developement bool `json:"developement"` } // Config to allow for Prisma Agents diff --git a/hatchery/cur.go b/hatchery/cur.go new file mode 100644 index 0000000..9e2f7de --- /dev/null +++ b/hatchery/cur.go @@ -0,0 +1,86 @@ +package hatchery + +import ( + "fmt" + "strconv" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/session" + "github.com/aws/aws-sdk-go/service/costexplorer" +) + +type costUsage struct { + Username string `json:"username"` + TotalCost float64 `json:"total-cost"` +} + +// This function will get called in the module that calls `getCostUsageRepot` +var initializeCostExplorerClient = func() *CostExplorerClient { + // Create an interface to CostExplorer service client + sess := session.Must(session.NewSessionWithOptions(session.Options{ + SharedConfigState: session.SharedConfigEnable, + })) + return &CostExplorerClient{ + CostExporer: costexplorer.New(sess), + } +} + +var getCostUsageReport = func(costexplorerclient *CostExplorerClient, username string, workflowname string) (*costUsage, error) { + // query cost usage report + // the CostExplorer service client is passed in as a parameter + + // Build the request with date range and filter + // Return costs by tags + req := &costexplorer.GetCostAndUsageInput{ + Metrics: []*string{ + aws.String("UnblendedCost"), + }, + TimePeriod: &costexplorer.DateInterval{ + // 1 year ago is max + Start: aws.String(time.Now().AddDate(-1, 0, 0).Format("2006-01-02")), + // Today + End: aws.String(time.Now().Format("2006-01-02")), + }, + Filter: &costexplorer.Expression{ + Tags: &costexplorer.TagValues{ + Key: aws.String("gen3username"), + Values: []*string{ + aws.String(userToResourceName(username, "user")), + }, + }, + }, + Granularity: aws.String("MONTHLY"), + } + + if workflowname != "" { + req.Filter = &costexplorer.Expression{ + Tags: &costexplorer.TagValues{ + Key: aws.String("gen3username"), + Values: []*string{ + aws.String(userToResourceName(username, "user")), + }, + }, + } + } + + // Call Cost Explorer API + resp, err := costexplorerclient.CostExporer.GetCostAndUsage(req) + if err != nil { + fmt.Println("Got error calling GetCostAndUsage:", err) + return nil, err + } + var total float64 + for _, result := range resp.ResultsByTime { + // Get amount + totalAmount := result.Total["UnblendedCost"] + amount, _ := strconv.ParseFloat(*totalAmount.Amount, 64) + + // Sum amounts + total += amount + } + + ret := costUsage{Username: username, TotalCost: total} + + return &ret, nil +} diff --git a/hatchery/cur_test.go b/hatchery/cur_test.go new file mode 100644 index 0000000..c2f51b8 --- /dev/null +++ b/hatchery/cur_test.go @@ -0,0 +1,100 @@ +package hatchery + +import ( + "reflect" + "testing" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/costexplorer" + "github.com/aws/aws-sdk-go/service/costexplorer/costexploreriface" +) + +type MockCostOutput struct { + output costexplorer.GetCostAndUsageOutput +} + +type CostExplorerMockClient struct { + costexploreriface.CostExplorerAPI + mockOutput *MockCostOutput +} + +func (m *CostExplorerMockClient) GetCostAndUsage(input *costexplorer.GetCostAndUsageInput) (*costexplorer.GetCostAndUsageOutput, error) { + return &m.mockOutput.output, nil +} + +func Test_GetCostUsageReport(t *testing.T) { + defer SetupAndTeardownTest()() + + testCases := []struct { + name string + want *costUsage + mockCostOutput *MockCostOutput + }{ + { + name: "UserHasCosts", + want: &costUsage{Username: "test_user", TotalCost: 100}, + mockCostOutput: &MockCostOutput{ + output: costexplorer.GetCostAndUsageOutput{ + ResultsByTime: []*costexplorer.ResultByTime{ + { + TimePeriod: &costexplorer.DateInterval{ + // 1 year ago is max + Start: aws.String(time.Now().AddDate(-1, 0, 0).Format("2006-01-02")), + // Today + End: aws.String(time.Now().Format("2006-01-02")), + }, + Total: map[string]*costexplorer.MetricValue{ + "UnblendedCost": { + Amount: aws.String("100"), + Unit: aws.String("USD"), + }, + }, + }, + }, + }, + }, + }, + { + name: "NoUserCosts", + want: &costUsage{Username: "test_user", TotalCost: 0}, + mockCostOutput: &MockCostOutput{}, + }, + } + + // Backing up original functions before mocking + original_getCostUsageReport := getCostUsageReport + defer func() { + // restore original functions + getCostUsageReport = original_getCostUsageReport + }() + + // mock the cost explorer interface and cost report + costexplorerclient := initializeCostExplorerClient() + + for _, testcase := range testCases { + t.Logf("Testing GetCostUsageReport when %s", testcase.name) + + costexplorerclient.CostExporer = &CostExplorerMockClient{ + CostExplorerAPI: nil, + mockOutput: testcase.mockCostOutput, + } + + /* Act */ + got, err := getCostUsageReport(costexplorerclient, "test_user", "Direct Pay") + if nil != err { + t.Errorf("failed to get cost usage report, got: %v", err) + return + } + + /* Assert */ + if reflect.TypeOf(got) != reflect.TypeOf(testcase.want) { + t.Errorf("Return value is not correct type:\ngot: '%v'\nwant: '%v'", + reflect.TypeOf(got), reflect.TypeOf(testcase.want)) + } + if !reflect.DeepEqual(got, testcase.want) { + t.Errorf("\nassertion error while testing `getCostUsageReport`: \nWant:%+v\nGot:%+v", testcase.want, got) + } + + } +} diff --git a/hatchery/hatchery.go b/hatchery/hatchery.go index 8fc760d..e11e9d3 100644 --- a/hatchery/hatchery.go +++ b/hatchery/hatchery.go @@ -46,11 +46,47 @@ func RegisterHatchery() { http.HandleFunc("/setpaymodel", setpaymodel) http.HandleFunc("/resetpaymodels", resetPaymodels) http.HandleFunc("/allpaymodels", allpaymodels) + http.HandleFunc("/cost", cost) // ECS functions http.HandleFunc("/create-ecs-cluster", createECSCluster) } +func cost(w http.ResponseWriter, r *http.Request) { + // create context for http call + // context, cancel := context.WithTimeout(r.Context(), 30*time.Second) + // defer cancel() + + userName := getCurrentUserName(r) + + workflowname := r.URL.Query().Get("workflowname") + // check if workflowname is empty + + // get cost usage report + costexplorerclient := initializeCostExplorerClient() + costUsageReport, err := getCostUsageReport(costexplorerclient, userName, workflowname) + if err != nil { + Config.Logger.Print(err) + // Send 500 error + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + + // make a return object with username and cost + cur, err := json.Marshal(costUsageReport) + if err != nil { + Config.Logger.Print(err) + // Send 500 error + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + // Add header indicating this is a json response + w.Header().Set("Content-Type", "application/json") + + // return json + fmt.Fprint(w, string(cur)) +} + func home(w http.ResponseWriter, r *http.Request) { htmlHeader := ` Gen3 Hatchery @@ -118,6 +154,8 @@ func paymodels(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusInternalServerError) return } + // Set header to indicate it's a json response + w.Header().Set("Content-Type", "application/json") fmt.Fprint(w, string(out)) } @@ -142,6 +180,8 @@ func allpaymodels(w http.ResponseWriter, r *http.Request) { http.Error(w, err.Error(), http.StatusInternalServerError) return } + // Set header to indicate it's a json response + w.Header().Set("Content-Type", "application/json") fmt.Fprint(w, string(out)) } @@ -479,6 +519,7 @@ func terminate(w http.ResponseWriter, r *http.Request) { http.Error(w, "Method Not Allowed", http.StatusMethodNotAllowed) return } + accessToken := getBearerToken(r) userName := getCurrentUserName(r) if userName == "" { @@ -511,10 +552,10 @@ func terminate(w http.ResponseWriter, r *http.Request) { // delete nextflow resources. There is no way to know if the actual workspace being // terminated is a nextflow workspace or not, so always attempt to delete - Config.Logger.Printf("Info: Deleting Nextflow resources in AWS...") + Config.Logger.Print("Info: Deleting Nextflow resources in AWS...") err := cleanUpNextflowResources(userName, nil, nil, nil, nil) if err != nil { - Config.Logger.Printf("Unable to delete AWS resources for Nextflow... continuing anyway") + Config.Logger.Print("Unable to delete AWS resources for Nextflow... continuing anyway") } payModel, err := getCurrentPayModel(userName) @@ -528,7 +569,7 @@ func terminate(w http.ResponseWriter, r *http.Request) { return } else { Config.Logger.Printf("Succesfully terminated all resources related to ECS workspace for user %s", userName) - fmt.Fprintf(w, "Terminated ECS workspace") + fmt.Fprint(w, "Terminated ECS workspace") } } else { err := deleteK8sPod(r.Context(), userName, accessToken, payModel) @@ -537,27 +578,30 @@ func terminate(w http.ResponseWriter, r *http.Request) { return } Config.Logger.Printf("Terminated workspace for user %s", userName) - fmt.Fprintf(w, "Terminated workspace") + fmt.Fprint(w, "Terminated workspace") } - // Need to reset pay model only after workspace termination is completed. - go func() { - // Periodically poll for status, until it is set as "Not Found" - for { - status, err := getWorkspaceStatus(r.Context(), userName, accessToken) - if err != nil { - Config.Logger.Printf("error fetching workspace status for user %s\n err: %s", userName, err) + // check if dynamoDB is enabled + if Config.Config.PayModelsDynamodbTable != "" { + // Need to reset pay model only after workspace termination is completed. + go func() { + // Periodically poll for status, until it is set as "Not Found" + for { + status, err := getWorkspaceStatus(r.Context(), userName, accessToken) + if err != nil { + Config.Logger.Printf("error fetching workspace status for user %s\n err: %s", userName, err) + } + if status.Status == "Not Found" { + break + } + time.Sleep(5 * time.Second) } - if status.Status == "Not Found" { - break + err = resetCurrentPaymodel(userName) + if err != nil { + Config.Logger.Printf("unable to reset current paymodel for current user %s\nerr: %s", userName, err) } - time.Sleep(5 * time.Second) - } - err = resetCurrentPaymodel(userName) - if err != nil { - Config.Logger.Printf("unable to reset current paymodel for current user %s\nerr: %s", userName, err) - } - }() + }() + } } func getBearerToken(r *http.Request) string { diff --git a/hatchery/hatchery_test.go b/hatchery/hatchery_test.go index 426ea1f..6c2a031 100644 --- a/hatchery/hatchery_test.go +++ b/hatchery/hatchery_test.go @@ -730,6 +730,7 @@ func Test_TerminateEndpoint(t *testing.T) { waitToTerminate bool throwError bool calledFunctionName string + noPayModelTable bool }{ { name: "MethodIsNotPost", @@ -758,6 +759,18 @@ func Test_TerminateEndpoint(t *testing.T) { mockCurrentPayModel: nil, calledFunctionName: "deleteK8sPod", }, + { + name: "NoPayModelDBTableExists", + want: "Terminated workspace", + wantStatus: http.StatusOK, + mockRequest: &RequestBody{ + Method: "POST", + username: "testUser", + }, + mockCurrentPayModel: nil, + calledFunctionName: "deleteK8sPod", + noPayModelTable: true, + }, { name: "NonEcsPayModelExists", want: "Terminated workspace", @@ -837,6 +850,7 @@ func Test_TerminateEndpoint(t *testing.T) { original_getLicenseUserMapsForUser := getLicenseUserMapsForUser original_getWorkspaceStatus := getWorkspaceStatus original_resetCurrentPaymodel := resetCurrentPaymodel + original_payModelTable := Config.Config.PayModelsDynamodbTable defer func() { // restore original functions deleteK8sPod = original_deleteK8sPod @@ -845,6 +859,7 @@ func Test_TerminateEndpoint(t *testing.T) { getLicenseUserMapsForUser = original_getLicenseUserMapsForUser getWorkspaceStatus = original_getWorkspaceStatus resetCurrentPaymodel = original_resetCurrentPaymodel + Config.Config.PayModelsDynamodbTable = original_payModelTable }() for _, testcase := range testCases { @@ -867,6 +882,9 @@ func Test_TerminateEndpoint(t *testing.T) { if testcase.throwError { return errors.New("error deleting k8s pod") } + if testcase.noPayModelTable { + waitGroup.Done() + } return nil } terminateEcsWorkspace = func(ctx context.Context, userName, accessToken, awsAcctID string) (string, error) { @@ -875,6 +893,9 @@ func Test_TerminateEndpoint(t *testing.T) { if testcase.throwError { return "", errors.New("error terminating ecs workspace") } + if testcase.noPayModelTable { + waitGroup.Done() + } return "", nil } @@ -898,6 +919,12 @@ func Test_TerminateEndpoint(t *testing.T) { return nil } + if testcase.noPayModelTable { + Config.Config.PayModelsDynamodbTable = "" + } else { + Config.Config.PayModelsDynamodbTable = "paymodelTableName" + } + getLicenseUserMapsForUser = func(dbconfig *DbConfig, userId string) ([]Gen3LicenseUserMap, error) { return []Gen3LicenseUserMap{}, nil } @@ -948,11 +975,19 @@ func Test_TerminateEndpoint(t *testing.T) { t.Errorf("Expected to call workspaceStatus more than once , but is called %d time(s)", workspaceStatusCallCounter) } + if testcase.noPayModelTable { + if !testcase.waitToTerminate && workspaceStatusCallCounter != 0 { + t.Errorf("Expected to call workspaceStatus exactly 0 times , but is called %d time(s)", + workspaceStatusCallCounter) + } + } else { + if !testcase.waitToTerminate && workspaceStatusCallCounter != 1 { + t.Errorf("Expected to call workspaceStatus exactly once , but is called %d time(s)", + workspaceStatusCallCounter) + } - if !testcase.waitToTerminate && workspaceStatusCallCounter != 1 { - t.Errorf("Expected to call workspaceStatus exactly once , but is called %d time(s)", - workspaceStatusCallCounter) } + } } } @@ -1129,3 +1164,113 @@ aws { return } } + +func Test_CostEndpoint(t *testing.T) { + defer SetupAndTeardownTest()() + + type RequestBody struct { + Method string + username string + workflowname string + } + + testCases := []struct { + name string + want string + wantStatus int + mockRequest *RequestBody + mockCostUsageReport *costUsage + mockCurrentPayModel *PayModel + waitToTerminate bool + throwError bool + calledFunctionName string + noPayModelTable bool + }{ + { + name: "UserHasTotalCost", + want: `{"username":"testUser","total-cost":2.5}`, + wantStatus: http.StatusOK, + mockRequest: &RequestBody{ + Method: "GET", + username: "testUser", + workflowname: "Direct+Pay", + }, + mockCostUsageReport: &costUsage{ + Username: "testUser", + TotalCost: 2.5, + }, + }, + { + name: "MissingUsername", + want: `{"username":"","total-cost":0}`, + wantStatus: http.StatusOK, + mockRequest: &RequestBody{ + Method: "GET", + }, + mockCostUsageReport: &costUsage{ + Username: "", + TotalCost: 0, + }, + }, + { + name: "MissingWorkflowname", + want: `{"username":"testUser","total-cost":2.5}`, + wantStatus: http.StatusOK, + mockRequest: &RequestBody{ + Method: "GET", + username: "testUser", + }, + mockCostUsageReport: &costUsage{ + Username: "testUser", + TotalCost: 2.5, + }, + }, + } + + // Backing up original functions before mocking + original_getCostUsageReport := getCostUsageReport + defer func() { + // restore original functions + getCostUsageReport = original_getCostUsageReport + }() + + for _, testcase := range testCases { + t.Logf("Testing Terminate Endpoint when %s", testcase.name) + + /* Setup */ + getCostUsageReport = func(costexplorerclient *CostExplorerClient, username string, workflowname string) (*costUsage, error) { + fmt.Println("COST mock function called") + return testcase.mockCostUsageReport, nil + } + + url := "/cost" + if testcase.mockRequest.workflowname != "" { + url = url + "?workflowname=" + testcase.mockRequest.workflowname + } + fmt.Printf("TEST: url %s\n", url) + req, err := http.NewRequest(testcase.mockRequest.Method, url, nil) + if testcase.mockRequest.username != "" { + req.Header.Set("REMOTE_USER", testcase.mockRequest.username) + } + if err != nil { + t.Fatal(err) + } + w := httptest.NewRecorder() + + /* Act */ + handler := http.HandlerFunc(cost) + handler.ServeHTTP(w, req) + + /* Assert */ + if testcase.wantStatus != w.Code { + t.Errorf("handler returned wrong status code:\ngot: '%v'\nwant: '%v'", + w.Code, testcase.wantStatus) + } + + if testcase.want != strings.TrimSpace(w.Body.String()) { + t.Errorf("handler returned wrong response:\ngot: '%v'\nwant: '%v'", + w.Body.String(), testcase.want) + } + } + +} diff --git a/hatchery/karpenter.go b/hatchery/karpenter.go new file mode 100644 index 0000000..2a8fbbc --- /dev/null +++ b/hatchery/karpenter.go @@ -0,0 +1,250 @@ +package hatchery + +import ( + "context" + "fmt" + "os" + + apierrors "k8s.io/apimachinery/pkg/api/errors" + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/apis/meta/v1/unstructured" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/client-go/dynamic" + "k8s.io/client-go/rest" + // "sigs.k8s.io/controller-runtime/pkg/client" +) + +// Create karpenter AWSNodeTemplate +func createKarpenterAWSNodeTemplate(ctx context.Context, userName string, client dynamic.Interface) error { + + jupyterTemplate, err := getJupyterAWSNodeTemplate(ctx, client) + if err != nil { + return err + } + + // Create a unstructured object. + u := &unstructured.Unstructured{} + + u.Object = map[string]interface{}{ + "metadata": map[string]interface{}{ + "name": userToResourceName(userName, "pod"), + }, + "spec": jupyterTemplate.Object["spec"], + } + + // Update tags + u.Object["spec"].(map[string]interface{})["tags"] = map[string]interface{}{ + "Name": fmt.Sprintf("eks-%s-jupyter-karpenter", os.Getenv("GEN3_VPCID")), + "Environment": os.Getenv("GEN3_VPCID"), + "Organization": os.Getenv("GEN3_ENDPOINT"), + "karpenter.sh/discovery": os.Getenv("GEN3_VPCID"), + "gen3.io/role": userToResourceName(userName, "pod"), + "gen3username": userToResourceName(userName, "user"), + "gen3.io/environment": os.Getenv("GEN3_ENDPOINT"), + } + u.SetGroupVersionKind(schema.GroupVersionKind{ + Group: "karpenter.k8s.aws", + Version: "v1alpha1", + Kind: "AWSNodeTemplate", + }) + + gvr := schema.GroupVersionResource{ + Group: "karpenter.k8s.aws", + Version: "v1alpha1", + Resource: "awsnodetemplates", + } + + // Delete the AWSNodeTemplate if it exists + err = client.Resource(gvr).Delete(ctx, userToResourceName(userName, "pod"), metav1.DeleteOptions{}) + if err != nil { + if !apierrors.IsNotFound(err) { + Config.Logger.Print("Error deleting Karpenter AWSNodeTemplate") + return err + } + } + // Create the AWSNodeTemplate. + _, err = client.Resource(gvr).Create(ctx, u, metav1.CreateOptions{}) + if err != nil { + Config.Logger.Print("Error deleting Karpenter AWSNodeTemplate") + return err + } + Config.Logger.Print("Created Karpenter AWSNodeTemplate") + return nil +} + +// // Create karpenter provisioner +func createKarpenterProvisioner(ctx context.Context, userName string, client dynamic.Interface) error { + + jupyterTemplate, err := getJupyterProvisioner(client) + if err != nil { + Config.Logger.Print("Error getting Jupyter provisioner: ", err) + return err + } + + // // Using an unstructured object. + u := &unstructured.Unstructured{} + u.Object = map[string]interface{}{ + "metadata": map[string]interface{}{ + "name": userToResourceName(userName, "pod"), + }, + "spec": jupyterTemplate.Object["spec"], + } + + // Update role to match the user + u.Object["spec"].(map[string]interface{})["labels"] = map[string]interface{}{ + "role": userToResourceName(userName, "pod"), + } + + // update taints to match users role + u.Object["spec"].(map[string]interface{})["taints"] = []interface{}{ + map[string]interface{}{ + "key": "role", + "value": userToResourceName(userName, "pod"), + "effect": "NoSchedule", + }, + } + + // update providerref + u.Object["spec"].(map[string]interface{})["providerRef"] = map[string]interface{}{ + "name": userToResourceName(userName, "pod"), + } + + u.SetGroupVersionKind(schema.GroupVersionKind{ + Group: "karpenter.sh", + Kind: "Provisioner", + Version: "v1alpha5", + }) + + gvr := schema.GroupVersionResource{ + Group: "karpenter.sh", + Version: "v1alpha5", + Resource: "provisioners", + } + + // Check if the provisioner exists already, if it does delete the existing one + err = client.Resource(gvr).Delete(ctx, userToResourceName(userName, "pod"), metav1.DeleteOptions{}) + if err != nil { + if !apierrors.IsNotFound(err) { + Config.Logger.Print("Error deleting Karpenter provisioner") + return err + } + } + + // crete the provisioner + _, err = client.Resource(gvr).Create(ctx, u, metav1.CreateOptions{}) + if err != nil { + Config.Logger.Print("Error creating Karpenter provisioner inside here") + return err + } + Config.Logger.Printf("Created Karpenter provisioner %s... \n", userToResourceName(userName, "pod")) + return nil +} + +// Delete karpenter provisioner +func deleteKarpenterProvisioner(ctx context.Context, userName string, client dynamic.Interface) error { + // Delete the provisioner + err := client.Resource(schema.GroupVersionResource{ + Group: "karpenter.sh", + Version: "v1alpha5", + Resource: "provisioners", + }).Delete(ctx, userToResourceName(userName, "pod"), metav1.DeleteOptions{}) + if err != nil { + if !apierrors.IsNotFound(err) { + return err + } + } + return nil +} + +// Delete karpenter AWSNodeTemplate +func deleteKarpenterAWSNodeTemplate(ctx context.Context, userName string, client dynamic.Interface) error { + + // Delete the AWSNodeTemplate + err := client.Resource(schema.GroupVersionResource{ + Group: "karpenter.k8s.aws", + Version: "v1alpha1", + Resource: "awsnodetemplates", + }).Delete(ctx, userToResourceName(userName, "pod"), metav1.DeleteOptions{}) + if err != nil { + if !apierrors.IsNotFound(err) { + return err + } + } + return nil + +} + +func getJupyterAWSNodeTemplate(ctx context.Context, client dynamic.Interface) (*unstructured.Unstructured, error) { + + res, err := client.Resource(schema.GroupVersionResource{ + Group: "karpenter.k8s.aws", + Version: "v1alpha1", + Resource: "awsnodetemplates", + }).Get(context.Background(), "jupyter", metav1.GetOptions{}) + if err != nil { + Config.Logger.Print("Error getting Jupyter AWSNodeTemplate: ", err) + return nil, err + } + + return res, nil +} + +func getJupyterProvisioner(client dynamic.Interface) (*unstructured.Unstructured, error) { + + res, err := client.Resource(schema.GroupVersionResource{ + Group: "karpenter.sh", + Version: "v1alpha5", + Resource: "provisioners", + }).Get(context.TODO(), "jupyter", metav1.GetOptions{}) + if err != nil { + Config.Logger.Print("Error getting Jupyter provisioner: ", err) + return nil, err + } + return res, nil +} + +func createKarpenterResources(userName string) error { + // creates the in-cluster config + config, err := getKubeConfig() + if err != nil { + Config.Logger.Printf("Error creating kubeconfig: %v", err) + return err + } + // create context + ctx := context.Background() + + // create dynamic client + client, err := dynamic.NewForConfig(config) + if err != nil { + Config.Logger.Printf("Error creating dynamic client: %v", err) + return err + } + + err = createKarpenterAWSNodeTemplate(ctx, userName, client) + if err != nil { + return err + } + err = createKarpenterProvisioner(ctx, userName, client) + if err != nil { + return err + } + return nil +} + +func deleteKarpenterResources(ctx context.Context, userName string, config *rest.Config) error { + // create dynamic client + client, err := dynamic.NewForConfig(config) + if err != nil { + Config.Logger.Printf("Error creating dynamic client: %v", err) + return err + } + err = deleteKarpenterAWSNodeTemplate(ctx, userName, client) + if err != nil { + return err + } + err = deleteKarpenterProvisioner(ctx, userName, client) + if err != nil { + return err + } + return nil +} diff --git a/hatchery/paymodels.go b/hatchery/paymodels.go index b34ddee..8018cf3 100644 --- a/hatchery/paymodels.go +++ b/hatchery/paymodels.go @@ -3,6 +3,7 @@ package hatchery import ( "errors" "fmt" + "strings" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/credentials/stscreds" @@ -135,6 +136,25 @@ var getCurrentPayModel = func(userName string) (result *PayModel, err error) { Config.Logger.Printf("Got error unmarshalling: %s", err) return nil, err } + + if payModel.Local && Config.Config.Karpenter && strings.Contains(strings.ToLower(payModel.Name), "trial") { + + // get cost usage report + costexplorerclient := initializeCostExplorerClient() + costUsage, err := getCostUsageReport(costexplorerclient, userName, "") + if err != nil { + Config.Logger.Printf("Got error getting cost usage report: %s", err) + return nil, err + } + payModel.TotalUsage = float32(costUsage.TotalCost) + if payModel.HardLimit == 0 && Config.Config.DefaultHardLimit != 0 { + payModel.HardLimit = Config.Config.DefaultHardLimit + } + if payModel.SoftLimit == 0 && Config.Config.DefaultSoftLimit != 0 { + payModel.SoftLimit = Config.Config.DefaultSoftLimit + } + } + return &payModel, nil } diff --git a/hatchery/paymodels_test.go b/hatchery/paymodels_test.go index 634cfdc..393b241 100644 --- a/hatchery/paymodels_test.go +++ b/hatchery/paymodels_test.go @@ -179,6 +179,247 @@ func Test_GetCurrentPayModel(t *testing.T) { } } } + +func Test_GetCurrentPayModelWithLimits(t *testing.T) { + defer SetupAndTeardownTest()() + + configWithLimits := &FullHatcheryConfig{ + Config: HatcheryConfig{ + PayModelsDynamodbTable: "random_non_empty_string", + Karpenter: true, + DefaultHardLimit: float32(12), + DefaultSoftLimit: float32(7), + }, + } + + configWithNoLimits := &FullHatcheryConfig{ + Config: HatcheryConfig{ + PayModelsDynamodbTable: "random_non_empty_string", + Karpenter: true, + }, + } + + configNoKarpenter := &FullHatcheryConfig{ + Config: HatcheryConfig{ + PayModelsDynamodbTable: "random_non_empty_string", + Karpenter: false, + }, + } + + defaultPayModelForTest := &PayModel{ + Name: "Trial Workspace", + Local: true, + } + + testCases := []struct { + name string + want *PayModel + mockConfig *FullHatcheryConfig + mockDefaultPaymodel *PayModel + mockCurrentPayModelFromDB []PayModel + mockPayModelsFromDB []PayModel + }{ + { + name: "CurrentPayModelDBHasLimits", + want: &PayModel{ + Id: "#1", + Name: "Trial Workspace", + Local: true, + CurrentPayModel: true, + Status: "active", + HardLimit: 10, + SoftLimit: 5, + }, + mockConfig: configWithLimits, + mockCurrentPayModelFromDB: []PayModel{ + { + Id: "#1", + Name: "Trial Workspace", + Local: true, + CurrentPayModel: true, + Status: "active", + HardLimit: 10, + SoftLimit: 5, + }, + }, + mockPayModelsFromDB: []PayModel{ + { + Id: "#1", + Name: "Direct Pay", + Local: true, + CurrentPayModel: true, + Status: "active", + HardLimit: 10, + SoftLimit: 5, + }, + { + Id: "#2", + Name: "Direct Pay", + CurrentPayModel: false, + Status: "active", + }, + }, + mockDefaultPaymodel: nil, + }, + { + name: "CurrentPayModelConfigLimits", + want: &PayModel{ + Id: "#1", + Name: "Trial Workspace", + Local: true, + CurrentPayModel: true, + Status: "active", + HardLimit: 12, + SoftLimit: 7, + }, + mockConfig: configWithLimits, + mockCurrentPayModelFromDB: []PayModel{ + { + Id: "#1", + Name: "Trial Workspace", + Local: true, + CurrentPayModel: true, + Status: "active", + }, + }, + mockPayModelsFromDB: []PayModel{ + { + Id: "#1", + Name: "Trial Workspace", + Local: true, + CurrentPayModel: true, + Status: "active", + }, + { + Id: "#2", + Name: "Trial Workspace", + CurrentPayModel: false, + Status: "active", + }, + }, + mockDefaultPaymodel: nil, + }, + { + name: "CurrentPayModelNoKarpenter", + want: &PayModel{ + Id: "#1", + Name: "Trial Workspace", + Local: true, + CurrentPayModel: true, + Status: "active", + }, + mockConfig: configNoKarpenter, + mockCurrentPayModelFromDB: []PayModel{ + { + Id: "#1", + Name: "Trial Workspace", + Local: true, + CurrentPayModel: true, + Status: "active", + }, + }, + mockPayModelsFromDB: []PayModel{ + { + Id: "#1", + Name: "Trial Workspace", + Local: true, + CurrentPayModel: true, + Status: "active", + }, + { + Id: "#2", + Name: "Trial Workspace", + CurrentPayModel: false, + Status: "active", + }, + }, + mockDefaultPaymodel: nil, + }, + { + name: "NeitherCurrentPayModelNorConfigHaveLimits", + want: &PayModel{ + Id: "#1", + Name: "Trial Workspace", + Local: true, + CurrentPayModel: true, + Status: "active", + }, + mockConfig: configWithNoLimits, + mockCurrentPayModelFromDB: []PayModel{ + { + Id: "#1", + Name: "Trial Workspace", + Local: true, + CurrentPayModel: true, + Status: "active", + }, + }, + mockPayModelsFromDB: []PayModel{ + { + Id: "#1", + Name: "Trial Workspace", + Local: true, + CurrentPayModel: true, + Status: "active", + }, + { + Id: "#2", + Name: "Trial Workspace", + CurrentPayModel: false, + Status: "active", + }, + }, + mockDefaultPaymodel: defaultPayModelForTest, + }, + } + + // Backing up original functions before mocking + original_getDefaultPayModel := getDefaultPayModel + original_payModelsFromDatabase := payModelsFromDatabase + original_getCostUsageReport := getCostUsageReport + defer func() { + // restore original functions + getDefaultPayModel = original_getDefaultPayModel + payModelsFromDatabase = original_payModelsFromDatabase + getCostUsageReport = original_getCostUsageReport + }() + + for _, testcase := range testCases { + t.Logf("Testing GetCurrentPaymodelWithLimits when %s", testcase.name) + + /* Setup */ + Config = testcase.mockConfig + getDefaultPayModel = func() (*PayModel, error) { + return testcase.mockDefaultPaymodel, nil + } + payModelsFromDatabase = func(userName string, current bool) (payModels *[]PayModel, err error) { + if current { + return &testcase.mockCurrentPayModelFromDB, nil + } + return &testcase.mockPayModelsFromDB, nil + } + getCostUsageReport = func(costexplorerclient *CostExplorerClient, username string, workflowname string) (*costUsage, error) { + return &costUsage{Username: username, TotalCost: 0.00}, nil + } + + /* Act */ + got, err := getCurrentPayModel("testUser") + if nil != err { + t.Errorf("failed to load current pay model, got: %v", err) + return + } + + /* Assert */ + if testcase.want == nil { + if got != nil { + t.Errorf("\nassertion error while testing `GetPayModelsForUser` when %s : \nWant: %+v\nGot:%+v", testcase.name, testcase.want, got) + } + } else if !reflect.DeepEqual(got, testcase.want) { + t.Errorf("\nassertion error while testing `GetCurrentPayModel` when %s : \nWant:%+v\nGot:%+v", testcase.name, testcase.want, got) + } + } +} + func Test_GetPayModelsForUser(t *testing.T) { defer SetupAndTeardownTest()() diff --git a/hatchery/pods.go b/hatchery/pods.go index 7a035af..b64125b 100644 --- a/hatchery/pods.go +++ b/hatchery/pods.go @@ -6,6 +6,7 @@ import ( "fmt" "log" "os" + "path/filepath" "strconv" "strings" @@ -16,6 +17,7 @@ import ( "k8s.io/client-go/kubernetes" corev1 "k8s.io/client-go/kubernetes/typed/core/v1" "k8s.io/client-go/rest" + "k8s.io/client-go/tools/clientcmd" // AWS modules "github.com/aws/aws-sdk-go/aws" @@ -80,9 +82,25 @@ func getPodClient(ctx context.Context, userName string, payModelPtr *PayModel) ( } } +func getKubeConfig() (*rest.Config, error) { + config, err := rest.InClusterConfig() + if err != nil { + // Fallback to kubeconfig from .kube/config (for local testing) + kubeconfig := filepath.Join(os.Getenv("HOME"), ".kube", "config") + // Use kubeconfig + config, err = clientcmd.BuildConfigFromFlags("", kubeconfig) + if err != nil { + Config.Logger.Printf("Error creating in-cluster config: %v", err) + // Send 500 error + return nil, err + } + } + return config, nil +} + func getLocalPodClient() corev1.CoreV1Interface { // creates the in-cluster config - config, err := rest.InClusterConfig() + config, err := getKubeConfig() if err != nil { Config.Logger.Printf("Error creating in-cluster config: %v", err) return nil @@ -326,6 +344,19 @@ var deleteK8sPod = func(ctx context.Context, userName string, accessToken string fmt.Printf("Error occurred when deleting service: %s", err) } + Config.Logger.Print("Checking if karpenter stuff needs deleted") + if Config.Config.Karpenter { + Config.Logger.Print("Attempting to delete karpenter resources") + config, err := getKubeConfig() + if err != nil { + return err + } + err = deleteKarpenterResources(ctx, userName, config) + if err != nil { + Config.Logger.Print("Error occurred when deleting karpenter resources") + } + } + return nil } @@ -343,6 +374,9 @@ func userToResourceName(userName string, resourceType string) string { if resourceType == "mapping" { // ambassador mapping return fmt.Sprintf("%s-mapping", safeUserName) } + if resourceType == "user" { + return fmt.Sprintf("user-%s", safeUserName) + } return fmt.Sprintf("%s-%s", resourceType, safeUserName) } @@ -497,13 +531,25 @@ func buildPod(hatchConfig *FullHatcheryConfig, hatchApp *Container, userName str }) } + role := "jupyter" + if Config.Config.Karpenter { + role = userToResourceName(userName, "pod") + + err = createKarpenterResources(userName) + if err != nil { + Config.Logger.Print(err) + // Send 500 error + return nil, err + } + } + tolerations := []k8sv1.Toleration{} nodeSelector := map[string]string{} if !Config.Config.SkipNodeSelector { nodeSelector = map[string]string{ - "role": "jupyter", + "role": role, } - tolerations = []k8sv1.Toleration{{Key: "role", Operator: "Equal", Value: "jupyter", Effect: "NoSchedule", TolerationSeconds: nil}} + tolerations = []k8sv1.Toleration{{Key: "role", Operator: "Equal", Value: role, Effect: "NoSchedule", TolerationSeconds: nil}} } pod = &k8sv1.Pod{ diff --git a/hatchery/ram.go b/hatchery/ram.go index fb57562..01324a7 100644 --- a/hatchery/ram.go +++ b/hatchery/ram.go @@ -44,7 +44,7 @@ func acceptTransitGatewayShare(pm *PayModel, sess *session.Session, ramArn *stri } } else { // Log that resource share is already accepted - Config.Logger.Printf("Resource share already accepted") + Config.Logger.Print("Resource share already accepted") } return nil } @@ -71,7 +71,7 @@ func (creds *CREDS) acceptTGWShare(ramArn *string) error { // Check if we have an invitation to accept if len(resourceShareInvitation.ResourceShareInvitations) == 0 { // No invitation found, possible that we have to wait a bit for the invitation to show up. - Config.Logger.Printf("No resource share invitation found, waiting 10 seconds") + Config.Logger.Print("No resource share invitation found, waiting 10 seconds") time.Sleep(10 * time.Second) err := creds.acceptTGWShare(ramArn) @@ -88,11 +88,11 @@ func (creds *CREDS) acceptTGWShare(ramArn *string) error { return err } // Log that invitation was accepted - Config.Logger.Printf("Resource share invitation accepted") + Config.Logger.Print("Resource share invitation accepted") return nil } // Log that invitation was already accepted - Config.Logger.Printf("Resource share invitation already accepted") + Config.Logger.Print("Resource share invitation already accepted") return nil } } @@ -114,7 +114,7 @@ func shareTransitGateway(session *session.Session, tgwArn string, accountid stri return nil, err } if len(exRs.ResourceShares) == 0 { - Config.Logger.Printf("Did not find existing resource share, creating a resource share") + Config.Logger.Print("Did not find existing resource share, creating a resource share") resourceShareInput := &ram.CreateResourceShareInput{ // Indicates whether principals outside your organization in Organizations can // be associated with a resource share. @@ -139,7 +139,7 @@ func shareTransitGateway(session *session.Session, tgwArn string, accountid stri } return resourceShare.ResourceShare.ResourceShareArn, nil } else { - Config.Logger.Printf("Found existing resource share, associating resource share with account") + Config.Logger.Print("Found existing resource share, associating resource share with account") listResourcesInput := &ram.ListResourcesInput{ ResourceOwner: aws.String("SELF"), ResourceArns: []*string{&tgwArn}, diff --git a/main.go b/main.go index f66f3c8..f68df24 100644 --- a/main.go +++ b/main.go @@ -49,6 +49,14 @@ func main() { hatchery.RegisterSystem() hatchery.RegisterHatchery() + if config.Config.Karpenter { + config.Logger.Printf("Using karpenter for cost tracking.") + } + + if config.Config.Developement { + config.Logger.Printf("Using development mode.") + } + config.Logger.Printf("Running main") log.Fatal(http.ListenAndServe("0.0.0.0:8000", nil)) }