package watermilltest import ( "context" "time" "github.com/ThreeDotsLabs/watermill" "github.com/ThreeDotsLabs/watermill/message" "github.com/ThreeDotsLabs/watermill/pubsub/gochannel" "github.com/stretchr/testify/require" ) type Registerer interface { Register(router *message.Router) } func RunRouter(tb TestingTB, rs ...Registerer) *message.Router { tb.Helper() return RunRouterWithContext(tb, context.Background(), rs...) } const routerCloseTimeout = 10 * time.Second //nolint:revive func RunRouterWithContext(tb TestingTB, ctx context.Context, rs ...Registerer) *message.Router { tb.Helper() router, err := message.NewRouter(message.RouterConfig{CloseTimeout: routerCloseTimeout}, watermill.NopLogger{}) require.NoError(tb, err) tb.Cleanup(func() { _ = router.Close() }) for _, r := range rs { r.Register(router) } go func() { require.NoError(tb, router.Run(ctx)) }() <-router.Running() return router } type MiddlewareRegisterer struct { H message.HandlerMiddleware } var _ Registerer = MiddlewareRegisterer{} func (m MiddlewareRegisterer) Register(router *message.Router) { router.AddMiddleware(m.H) } func NewWaitForHandlerMiddleware(handlerName string) (MiddlewareRegisterer, <-chan struct{}) { ch := make(chan struct{}) return MiddlewareRegisterer{ H: func(h message.HandlerFunc) message.HandlerFunc { return func(msg *message.Message) ([]*message.Message, error) { if message.HandlerNameFromCtx(msg.Context()) != handlerName { return h(msg) } result, err := h(msg) if len(result) == 0 && err == nil { close(ch) } return result, err } }, }, ch } func NewPubSub(tb TestingTB) *gochannel.GoChannel { tb.Helper() pubSub := gochannel.NewGoChannel(gochannel.Config{Persistent: true}, watermill.NopLogger{}) tb.Cleanup(func() { _ = pubSub.Close() }) return pubSub }