Skip to content

Commit

Permalink
Test that ray and dask plugins bump phase version in GetTaskPhase
Browse files Browse the repository at this point in the history
Signed-off-by: Fabio Graetz <[email protected]>
  • Loading branch information
fg91 committed Apr 11, 2024
1 parent 4d52bd8 commit 374242b
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 23 deletions.
43 changes: 30 additions & 13 deletions flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ func dummyDaskTaskTemplate(customImage string, resources *core.Resources, podTem
}
}

func dummyDaskTaskContext(taskTemplate *core.TaskTemplate, resources *v1.ResourceRequirements, extendedResources *core.ExtendedResources, isInterruptible bool) pluginsCore.TaskExecutionContext {
func dummyDaskTaskContext(taskTemplate *core.TaskTemplate, resources *v1.ResourceRequirements, extendedResources *core.ExtendedResources, isInterruptible bool, pluginState k8s.PluginState) pluginsCore.TaskExecutionContext {
taskCtx := &mocks.TaskExecutionContext{}

inputReader := &pluginIOMocks.InputReader{}
Expand Down Expand Up @@ -199,11 +199,10 @@ func dummyDaskTaskContext(taskTemplate *core.TaskTemplate, resources *v1.Resourc
taskExecutionMetadata.OnGetOverrides().Return(overrides)
taskCtx.On("TaskExecutionMetadata").Return(taskExecutionMetadata)

inputState := k8s.PluginState{}
pluginStateReaderMock := mocks.PluginStateReader{}
pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&inputState).String())).Return(
pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&pluginState).String())).Return(
func(v interface{}) uint8 {
*(v.(*k8s.PluginState)) = inputState
*(v.(*k8s.PluginState)) = pluginState
return 0
},
func(v interface{}) error {
Expand All @@ -218,7 +217,7 @@ func TestBuildResourceDaskHappyPath(t *testing.T) {
daskResourceHandler := daskResourceHandler{}

taskTemplate := dummyDaskTaskTemplate("", nil, "")
taskContext := dummyDaskTaskContext(taskTemplate, &defaultResources, nil, false)
taskContext := dummyDaskTaskContext(taskTemplate, &defaultResources, nil, false, k8s.PluginState{})
r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext)
assert.Nil(t, err)
assert.NotNil(t, r)
Expand Down Expand Up @@ -329,7 +328,7 @@ func TestBuildResourceDaskCustomImages(t *testing.T) {

daskResourceHandler := daskResourceHandler{}
taskTemplate := dummyDaskTaskTemplate(customImage, nil, "")
taskContext := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, nil, false)
taskContext := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, nil, false, k8s.PluginState{})
r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext)
assert.Nil(t, err)
assert.NotNil(t, r)
Expand Down Expand Up @@ -362,7 +361,7 @@ func TestBuildResourceDaskDefaultResoureRequirements(t *testing.T) {

daskResourceHandler := daskResourceHandler{}
taskTemplate := dummyDaskTaskTemplate("", nil, "")
taskContext := dummyDaskTaskContext(taskTemplate, &flyteWorkflowResources, nil, false)
taskContext := dummyDaskTaskContext(taskTemplate, &flyteWorkflowResources, nil, false, k8s.PluginState{})
r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext)
assert.Nil(t, err)
assert.NotNil(t, r)
Expand Down Expand Up @@ -419,7 +418,7 @@ func TestBuildResourcesDaskCustomResoureRequirements(t *testing.T) {

daskResourceHandler := daskResourceHandler{}
taskTemplate := dummyDaskTaskTemplate("", &protobufResources, "")
taskContext := dummyDaskTaskContext(taskTemplate, &flyteWorkflowResources, nil, false)
taskContext := dummyDaskTaskContext(taskTemplate, &flyteWorkflowResources, nil, false, k8s.PluginState{})
r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext)
assert.Nil(t, err)
assert.NotNil(t, r)
Expand Down Expand Up @@ -474,7 +473,7 @@ func TestBuildResourceDaskInterruptible(t *testing.T) {
daskResourceHandler := daskResourceHandler{}

taskTemplate := dummyDaskTaskTemplate("", nil, "")
taskContext := dummyDaskTaskContext(taskTemplate, &defaultResources, nil, true)
taskContext := dummyDaskTaskContext(taskTemplate, &defaultResources, nil, true, k8s.PluginState{})
r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext)
assert.Nil(t, err)
assert.NotNil(t, r)
Expand Down Expand Up @@ -508,7 +507,7 @@ func TestBuildResouceDaskUsePodTemplate(t *testing.T) {
flytek8s.DefaultPodTemplateStore.Store(podTemplate)
daskResourceHandler := daskResourceHandler{}
taskTemplate := dummyDaskTaskTemplate("", nil, podTemplateName)
taskContext := dummyDaskTaskContext(taskTemplate, &defaultResources, nil, false)
taskContext := dummyDaskTaskContext(taskTemplate, &defaultResources, nil, false, k8s.PluginState{})
r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext)
assert.Nil(t, err)
assert.NotNil(t, r)
Expand Down Expand Up @@ -628,7 +627,7 @@ func TestBuildResourceDaskExtendedResources(t *testing.T) {
t.Run(f.name, func(t *testing.T) {
taskTemplate := dummyDaskTaskTemplate("", nil, "")
taskTemplate.ExtendedResources = f.extendedResourcesBase
taskContext := dummyDaskTaskContext(taskTemplate, f.resources, f.extendedResourcesOverride, false)
taskContext := dummyDaskTaskContext(taskTemplate, f.resources, f.extendedResourcesOverride, false, k8s.PluginState{})
daskResourceHandler := daskResourceHandler{}
r, err := daskResourceHandler.BuildResource(context.TODO(), taskContext)
assert.Nil(t, err)
Expand Down Expand Up @@ -694,7 +693,7 @@ func TestBuildIdentityResourceDask(t *testing.T) {
}

taskTemplate := dummyDaskTaskTemplate("", nil, "")
taskContext := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, nil, false)
taskContext := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, nil, false, k8s.PluginState{})
identityResources, err := daskResourceHandler.BuildIdentityResource(context.TODO(), taskContext.TaskExecutionMetadata())
if err != nil {
panic(err)
Expand All @@ -707,7 +706,7 @@ func TestGetTaskPhaseDask(t *testing.T) {
ctx := context.TODO()

taskTemplate := dummyDaskTaskTemplate("", nil, "")
taskCtx := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, nil, false)
taskCtx := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, nil, false, k8s.PluginState{})

taskPhase, err := daskResourceHandler.GetTaskPhase(ctx, taskCtx, dummyDaskJob(""))
assert.NoError(t, err)
Expand Down Expand Up @@ -751,3 +750,21 @@ func TestGetTaskPhaseDask(t *testing.T) {
assert.NotNil(t, taskPhase.Info().Logs)
assert.Nil(t, err)
}

func TestGetTaskPhaseIncreasePhaseVersion(t *testing.T) {
daskResourceHandler := daskResourceHandler{}
ctx := context.TODO()

pluginState := k8s.PluginState{
Phase: pluginsCore.PhaseInitializing,
PhaseVersion: pluginsCore.DefaultPhaseVersion,
Reason: "task submitted to K8s",
}
taskTemplate := dummyDaskTaskTemplate("", nil, "")
taskCtx := dummyDaskTaskContext(taskTemplate, &v1.ResourceRequirements{}, nil, false, pluginState)

taskPhase, err := daskResourceHandler.GetTaskPhase(ctx, taskCtx, dummyDaskJob(daskAPI.DaskJobCreated))

assert.NoError(t, err)
assert.Equal(t, taskPhase.Version(), pluginsCore.DefaultPhaseVersion+1)
}
39 changes: 29 additions & 10 deletions flyteplugins/go/tasks/plugins/k8s/ray/ray_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -681,7 +681,7 @@ func TestInjectLogsSidecar(t *testing.T) {
}
}

func newPluginContext() k8s.PluginContext {
func newPluginContext(pluginState k8s.PluginState) k8s.PluginContext {
plg := &mocks2.PluginContext{}

taskExecID := &mocks.TaskExecutionID{}
Expand Down Expand Up @@ -709,11 +709,10 @@ func newPluginContext() k8s.PluginContext {
tskCtx.OnGetTaskExecutionID().Return(taskExecID)
plg.OnTaskExecutionMetadata().Return(tskCtx)

inputState := k8s.PluginState{}
pluginStateReaderMock := mocks.PluginStateReader{}
pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&inputState).String())).Return(
pluginStateReaderMock.On("Get", mock.AnythingOfType(reflect.TypeOf(&pluginState).String())).Return(
func(v interface{}) uint8 {
*(v.(*k8s.PluginState)) = inputState
*(v.(*k8s.PluginState)) = pluginState
return 0
},
func(v interface{}) error {
Expand All @@ -739,7 +738,7 @@ func init() {
func TestGetTaskPhase(t *testing.T) {
ctx := context.Background()
rayJobResourceHandler := rayJobResourceHandler{}
pluginCtx := newPluginContext()
pluginCtx := newPluginContext(k8s.PluginState{})

testCases := []struct {
rayJobPhase rayv1alpha1.JobStatus
Expand Down Expand Up @@ -783,7 +782,7 @@ func TestGetTaskPhase(t *testing.T) {
func TestGetTaskPhase_V1(t *testing.T) {
ctx := context.Background()
rayJobResourceHandler := rayJobResourceHandler{}
pluginCtx := newPluginContext()
pluginCtx := newPluginContext(k8s.PluginState{})

testCases := []struct {
rayJobPhase rayv1.JobStatus
Expand Down Expand Up @@ -824,8 +823,28 @@ func TestGetTaskPhase_V1(t *testing.T) {
}
}

func TestGetTaskPhaseIncreasePhaseVersion(t *testing.T) {
rayJobResourceHandler := rayJobResourceHandler{}

ctx := context.TODO()

pluginState := k8s.PluginState{
Phase: pluginsCore.PhaseInitializing,
PhaseVersion: pluginsCore.DefaultPhaseVersion,
Reason: "task submitted to K8s",
}
pluginCtx := newPluginContext(pluginState)

rayObject := &rayv1alpha1.RayJob{}
rayObject.Status.JobDeploymentStatus = rayv1alpha1.JobDeploymentStatusInitializing
phaseInfo, err := rayJobResourceHandler.GetTaskPhase(ctx, pluginCtx, rayObject)

assert.NoError(t, err)
assert.Equal(t, phaseInfo.Version(), pluginsCore.DefaultPhaseVersion+1)
}

func TestGetEventInfo_LogTemplates(t *testing.T) {
pluginCtx := newPluginContext()
pluginCtx := newPluginContext(k8s.PluginState{})
testCases := []struct {
name string
rayJob rayv1alpha1.RayJob
Expand Down Expand Up @@ -924,7 +943,7 @@ func TestGetEventInfo_LogTemplates(t *testing.T) {
}

func TestGetEventInfo_LogTemplates_V1(t *testing.T) {
pluginCtx := newPluginContext()
pluginCtx := newPluginContext(k8s.PluginState{})
testCases := []struct {
name string
rayJob rayv1.RayJob
Expand Down Expand Up @@ -1023,7 +1042,7 @@ func TestGetEventInfo_LogTemplates_V1(t *testing.T) {
}

func TestGetEventInfo_DashboardURL(t *testing.T) {
pluginCtx := newPluginContext()
pluginCtx := newPluginContext(k8s.PluginState{})
testCases := []struct {
name string
rayJob rayv1alpha1.RayJob
Expand Down Expand Up @@ -1075,7 +1094,7 @@ func TestGetEventInfo_DashboardURL(t *testing.T) {
}

func TestGetEventInfo_DashboardURL_V1(t *testing.T) {
pluginCtx := newPluginContext()
pluginCtx := newPluginContext(k8s.PluginState{})
testCases := []struct {
name string
rayJob rayv1.RayJob
Expand Down

0 comments on commit 374242b

Please sign in to comment.