Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: handle nil input more gracefully #486

Merged
merged 1 commit into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions cmd/hatchet-engine/engine/run.go
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,7 @@ func Run(ctx context.Context, cf *loader.ConfigLoader) error {

if sc.HasService("workflowscontroller") {
wc, err := workflows.New(
workflows.WithAlerter(sc.Alerter),
workflows.WithMessageQueue(sc.MessageQueue),
workflows.WithRepository(sc.EngineRepository),
workflows.WithLogger(sc.Logger),
Expand Down
2 changes: 1 addition & 1 deletion internal/repository/prisma/get_group_key_run.go
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@ func (s *getGroupKeyRunRepository) GetGroupKeyRunForEngine(ctx context.Context,
}

if len(res) == 0 {
return nil, nil
return nil, fmt.Errorf("could not find group key run %s", getGroupKeyRunId)
}

return res[0], nil
Expand Down
60 changes: 35 additions & 25 deletions internal/repository/workflow_run.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,10 @@ func GetCreateWorkflowRunOptsFromManual(
input []byte,
additionalMetadata map[string]interface{},
) (*CreateWorkflowRunOpts, error) {
if input == nil {
input = []byte("{}")
}

opts := &CreateWorkflowRunOpts{
DisplayName: StringPtr(getWorkflowRunDisplayName(workflowVersion.WorkflowName)),
WorkflowVersionId: sqlchelpers.UUIDToStr(workflowVersion.WorkflowVersion.ID),
Expand All @@ -86,11 +90,9 @@ func GetCreateWorkflowRunOptsFromManual(
AdditionalMetadata: additionalMetadata,
}

if input != nil {
if workflowVersion.ConcurrencyLimitStrategy.Valid {
opts.GetGroupKeyRun = &CreateGroupKeyRunOpts{
Input: input,
}
if workflowVersion.ConcurrencyLimitStrategy.Valid {
opts.GetGroupKeyRun = &CreateGroupKeyRunOpts{
Input: input,
}
}

Expand All @@ -105,6 +107,10 @@ func GetCreateWorkflowRunOptsFromParent(
childKey *string,
additionalMetadata map[string]interface{},
) (*CreateWorkflowRunOpts, error) {
if input == nil {
input = []byte("{}")
}

opts := &CreateWorkflowRunOpts{
DisplayName: StringPtr(getWorkflowRunDisplayName(workflowVersion.WorkflowName)),
WorkflowVersionId: sqlchelpers.UUIDToStr(workflowVersion.WorkflowVersion.ID),
Expand All @@ -116,11 +122,9 @@ func GetCreateWorkflowRunOptsFromParent(

WithParent(parentId, parentStepRunId, childIndex, childKey)(opts)

if input != nil {
if workflowVersion.ConcurrencyLimitStrategy.Valid {
opts.GetGroupKeyRun = &CreateGroupKeyRunOpts{
Input: input,
}
if workflowVersion.ConcurrencyLimitStrategy.Valid {
opts.GetGroupKeyRun = &CreateGroupKeyRunOpts{
Input: input,
}
}

Expand All @@ -133,6 +137,10 @@ func GetCreateWorkflowRunOptsFromEvent(
input []byte,
additionalMetadata map[string]interface{},
) (*CreateWorkflowRunOpts, error) {
if input == nil {
input = []byte("{}")
}

opts := &CreateWorkflowRunOpts{
DisplayName: StringPtr(getWorkflowRunDisplayName(workflowVersion.WorkflowName)),
WorkflowVersionId: sqlchelpers.UUIDToStr(workflowVersion.WorkflowVersion.ID),
Expand All @@ -142,11 +150,9 @@ func GetCreateWorkflowRunOptsFromEvent(
AdditionalMetadata: additionalMetadata,
}

if input != nil {
if workflowVersion.ConcurrencyLimitStrategy.Valid {
opts.GetGroupKeyRun = &CreateGroupKeyRunOpts{
Input: input,
}
if workflowVersion.ConcurrencyLimitStrategy.Valid {
opts.GetGroupKeyRun = &CreateGroupKeyRunOpts{
Input: input,
}
}

Expand All @@ -160,6 +166,10 @@ func GetCreateWorkflowRunOptsFromCron(
input []byte,
additionalMetadata map[string]interface{},
) (*CreateWorkflowRunOpts, error) {
if input == nil {
input = []byte("{}")
}

opts := &CreateWorkflowRunOpts{
DisplayName: StringPtr(getWorkflowRunDisplayName(workflowVersion.WorkflowName)),
WorkflowVersionId: sqlchelpers.UUIDToStr(workflowVersion.WorkflowVersion.ID),
Expand All @@ -170,11 +180,9 @@ func GetCreateWorkflowRunOptsFromCron(
AdditionalMetadata: additionalMetadata,
}

if input != nil {
if workflowVersion.ConcurrencyLimitStrategy.Valid {
opts.GetGroupKeyRun = &CreateGroupKeyRunOpts{
Input: input,
}
if workflowVersion.ConcurrencyLimitStrategy.Valid {
opts.GetGroupKeyRun = &CreateGroupKeyRunOpts{
Input: input,
}
}

Expand All @@ -188,6 +196,10 @@ func GetCreateWorkflowRunOptsFromSchedule(
additionalMetadata map[string]interface{},
fs ...CreateWorkflowRunOpt,
) (*CreateWorkflowRunOpts, error) {
if input == nil {
input = []byte("{}")
}

opts := &CreateWorkflowRunOpts{
DisplayName: StringPtr(getWorkflowRunDisplayName(workflowVersion.WorkflowName)),
WorkflowVersionId: sqlchelpers.UUIDToStr(workflowVersion.WorkflowVersion.ID),
Expand All @@ -197,11 +209,9 @@ func GetCreateWorkflowRunOptsFromSchedule(
AdditionalMetadata: additionalMetadata,
}

if input != nil {
if workflowVersion.ConcurrencyLimitStrategy.Valid {
opts.GetGroupKeyRun = &CreateGroupKeyRunOpts{
Input: input,
}
if workflowVersion.ConcurrencyLimitStrategy.Valid {
opts.GetGroupKeyRun = &CreateGroupKeyRunOpts{
Input: input,
}
}

Expand Down
23 changes: 13 additions & 10 deletions internal/services/controllers/jobs/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"github.com/hatchet-dev/hatchet/internal/repository/prisma/dbsqlc"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/sqlchelpers"
"github.com/hatchet-dev/hatchet/internal/services/shared/defaults"
"github.com/hatchet-dev/hatchet/internal/services/shared/recoveryutils"
"github.com/hatchet-dev/hatchet/internal/services/shared/tasktypes"
"github.com/hatchet-dev/hatchet/internal/telemetry"
"github.com/hatchet-dev/hatchet/internal/telemetry/servertel"
Expand Down Expand Up @@ -198,7 +199,17 @@ func (jc *JobsControllerImpl) Start() (func() error, error) {
return cleanup, nil
}

func (ec *JobsControllerImpl) handleTask(ctx context.Context, task *msgqueue.Message) error {
func (ec *JobsControllerImpl) handleTask(ctx context.Context, task *msgqueue.Message) (err error) {
defer func() {
if r := recover(); r != nil {
recoverErr := recoveryutils.RecoverWithAlert(ec.l, ec.a, r)

if recoverErr != nil {
err = recoverErr
}
}
}()

switch task.ID {
case "job-run-queued":
return ec.handleJobRunQueued(ctx, task)
Expand Down Expand Up @@ -1044,15 +1055,7 @@ func (ec *JobsControllerImpl) cancelStepRun(ctx context.Context, tenantId, stepR
func (ec *JobsControllerImpl) handleStepRunUpdateInfo(stepRun *dbsqlc.GetStepRunForEngineRow, updateInfo *repository.StepRunUpdateInfo) {
defer func() {
if r := recover(); r != nil {
err, ok := r.(error)

if !ok {
err = fmt.Errorf("%v", r)
}

ec.l.Error().Err(err).Msg("recovered from panic")

return
recoveryutils.RecoverWithAlert(ec.l, ec.a, r) // nolint:errcheck
}
}()

Expand Down
43 changes: 35 additions & 8 deletions internal/services/controllers/workflows/controller.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@ import (
"github.com/hatchet-dev/hatchet/internal/repository/prisma/db"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/dbsqlc"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/sqlchelpers"
"github.com/hatchet-dev/hatchet/internal/services/shared/recoveryutils"
"github.com/hatchet-dev/hatchet/internal/services/shared/tasktypes"
"github.com/hatchet-dev/hatchet/internal/telemetry"
hatcheterrors "github.com/hatchet-dev/hatchet/pkg/errors"
)

type WorkflowsController interface {
Expand All @@ -33,23 +35,28 @@ type WorkflowsControllerImpl struct {
dv datautils.DataDecoderValidator
s gocron.Scheduler
tenantAlerter *alerting.TenantAlertManager
a *hatcheterrors.Wrapped
}

type WorkflowsControllerOpt func(*WorkflowsControllerOpts)

type WorkflowsControllerOpts struct {
mq msgqueue.MessageQueue
l *zerolog.Logger
repo repository.EngineRepository
dv datautils.DataDecoderValidator
ta *alerting.TenantAlertManager
mq msgqueue.MessageQueue
l *zerolog.Logger
repo repository.EngineRepository
dv datautils.DataDecoderValidator
ta *alerting.TenantAlertManager
alerter hatcheterrors.Alerter
}

func defaultWorkflowsControllerOpts() *WorkflowsControllerOpts {
logger := logger.NewDefaultLogger("workflows-controller")
alerter := hatcheterrors.NoOpAlerter{}

return &WorkflowsControllerOpts{
l: &logger,
dv: datautils.NewDataDecoderValidator(),
l: &logger,
dv: datautils.NewDataDecoderValidator(),
alerter: alerter,
}
}

Expand All @@ -71,6 +78,12 @@ func WithRepository(r repository.EngineRepository) WorkflowsControllerOpt {
}
}

func WithAlerter(a hatcheterrors.Alerter) WorkflowsControllerOpt {
return func(opts *WorkflowsControllerOpts) {
opts.alerter = a
}
}

func WithDataDecoderValidator(dv datautils.DataDecoderValidator) WorkflowsControllerOpt {
return func(opts *WorkflowsControllerOpts) {
opts.dv = dv
Expand Down Expand Up @@ -111,13 +124,17 @@ func New(fs ...WorkflowsControllerOpt) (*WorkflowsControllerImpl, error) {
newLogger := opts.l.With().Str("service", "workflows-controller").Logger()
opts.l = &newLogger

a := hatcheterrors.NewWrapped(opts.alerter)
a.WithData(map[string]interface{}{"service": "workflows-controller"})

return &WorkflowsControllerImpl{
mq: opts.mq,
l: opts.l,
repo: opts.repo,
dv: opts.dv,
s: s,
tenantAlerter: opts.ta,
a: a,
}, nil
}

Expand Down Expand Up @@ -193,7 +210,17 @@ func (wc *WorkflowsControllerImpl) Start() (func() error, error) {
return cleanup, nil
}

func (wc *WorkflowsControllerImpl) handleTask(ctx context.Context, task *msgqueue.Message) error {
func (wc *WorkflowsControllerImpl) handleTask(ctx context.Context, task *msgqueue.Message) (err error) {
defer func() {
if r := recover(); r != nil {
recoverErr := recoveryutils.RecoverWithAlert(wc.l, wc.a, r)

if recoverErr != nil {
err = recoverErr
}
}
}()

switch task.ID {
case "workflow-run-queued":
return wc.handleWorkflowRunQueued(ctx, task)
Expand Down
4 changes: 3 additions & 1 deletion internal/services/controllers/workflows/queue.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ func (wc *WorkflowsControllerImpl) handleWorkflowRunQueued(ctx context.Context,

// determine if we should start this workflow run or we need to limit its concurrency
// if the workflow has concurrency settings, then we need to check if we can start it
if workflowRun.ConcurrencyLimitStrategy.Valid {
if workflowRun.ConcurrencyLimitStrategy.Valid && workflowRun.GetGroupKeyRunId.Valid {
wc.l.Info().Msgf("workflow %s has concurrency settings", workflowRunId)

groupKeyRunId := sqlchelpers.UUIDToStr(workflowRun.GetGroupKeyRunId)
Expand All @@ -76,6 +76,8 @@ func (wc *WorkflowsControllerImpl) handleWorkflowRunQueued(ctx context.Context,
}

return nil
} else if workflowRun.ConcurrencyLimitStrategy.Valid && !workflowRun.GetGroupKeyRunId.Valid {
return fmt.Errorf("workflow run %s has concurrency settings but no group key run", workflowRunId)
}

err = wc.queueWorkflowRunJobs(ctx, workflowRun)
Expand Down
7 changes: 6 additions & 1 deletion internal/services/dispatcher/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/hatchet-dev/hatchet/internal/repository"
"github.com/hatchet-dev/hatchet/internal/repository/prisma/sqlchelpers"
"github.com/hatchet-dev/hatchet/internal/services/dispatcher/contracts"
"github.com/hatchet-dev/hatchet/internal/services/shared/recoveryutils"
"github.com/hatchet-dev/hatchet/internal/services/shared/tasktypes"
"github.com/hatchet-dev/hatchet/internal/telemetry"
"github.com/hatchet-dev/hatchet/internal/telemetry/servertel"
Expand Down Expand Up @@ -303,7 +304,11 @@ func (d *DispatcherImpl) Start() (func() error, error) {
func (d *DispatcherImpl) handleTask(ctx context.Context, task *msgqueue.Message) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("recovered from panic: %v", r)
recoverErr := recoveryutils.RecoverWithAlert(d.l, d.a, r)

if recoverErr != nil {
err = recoverErr
}
}
}()

Expand Down
29 changes: 29 additions & 0 deletions internal/services/shared/recoveryutils/recover.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package recoveryutils

import (
"fmt"
"runtime/debug"

"github.com/rs/zerolog"

hatcheterrors "github.com/hatchet-dev/hatchet/pkg/errors"
)

func RecoverWithAlert(l *zerolog.Logger, a *hatcheterrors.Wrapped, r any) error {
var ok bool
err, ok := r.(error)

if !ok {
err = fmt.Errorf("%v", r)
}

err = fmt.Errorf("recovered from panic: %w. Stack trace:\n%s", err, string(debug.Stack()))

l.Error().Err(err).Msgf("recovered from panic")

if a != nil {
return a.WrapErr(err, nil)
}

return err
}