core/internal/watermill/watermilltest/watermilltest.go

85 lines
1.8 KiB
Go

package watermilltest
import (
	"context"
	"sync"
	"time"

	"github.com/ThreeDotsLabs/watermill"
	"github.com/ThreeDotsLabs/watermill/message"
	"github.com/ThreeDotsLabs/watermill/pubsub/gochannel"
	"github.com/stretchr/testify/require"
)
// Registerer configures a message.Router before RunRouter /
// RunRouterWithContext start it (for example, MiddlewareRegisterer adds a
// middleware). Register is invoked once per Registerer, in argument order.
type Registerer interface {
	// Register attaches the implementer's configuration to router.
	Register(router *message.Router)
}
// RunRouter is RunRouterWithContext using context.Background(): it builds a
// router from rs, starts it, and returns once it is running. The router is
// closed during test cleanup.
func RunRouter(tb TestingTB, rs ...Registerer) *message.Router {
	tb.Helper()
	bg := context.Background()
	return RunRouterWithContext(tb, bg, rs...)
}
// routerCloseTimeout is passed as RouterConfig.CloseTimeout for routers built
// by RunRouterWithContext.
const routerCloseTimeout = 10 * time.Second

// RunRouterWithContext creates a message.Router, lets every Registerer in rs
// configure it, starts it with ctx, and blocks until the router reports that
// it is running. The router is closed, and its Run goroutine joined, during
// test cleanup.
//
// router.Run executes on a separate goroutine, so its error is never passed
// to require there: FailNow must only be called from the goroutine running
// the test. Instead the result is captured and reported from cleanup (which
// runs on the test goroutine) after the goroutine has exited. If Run returns
// before the router starts running, the test fails immediately instead of
// blocking forever on Running().
//
//nolint:revive
func RunRouterWithContext(tb TestingTB, ctx context.Context, rs ...Registerer) *message.Router {
	tb.Helper()

	router, err := message.NewRouter(message.RouterConfig{CloseTimeout: routerCloseTimeout}, watermill.NopLogger{})
	require.NoError(tb, err)

	for _, r := range rs {
		r.Register(router)
	}

	// runErr is written only by the goroutine before runDone is closed, and
	// read only after receiving from runDone, so access is race-free.
	var runErr error
	runDone := make(chan struct{})
	go func() {
		defer close(runDone)
		runErr = router.Run(ctx)
	}()

	tb.Cleanup(func() {
		_ = router.Close()
		// Join the Run goroutine so it never outlives the test.
		<-runDone
		if runErr != nil {
			tb.Errorf("router.Run: %v", runErr)
		}
	})

	select {
	case <-router.Running():
	case <-runDone:
		// Run returned before (or at the same moment as) the router signalled
		// it was running; surface a startup error right away.
		if runErr != nil {
			startErr := runErr
			runErr = nil // reported here; keep cleanup from reporting it twice
			require.NoError(tb, startErr, "router.Run returned before the router started running")
		}
	}
	return router
}
// MiddlewareRegisterer is a Registerer that installs a single router-level
// middleware.
type MiddlewareRegisterer struct {
	// H is the middleware passed to router.AddMiddleware.
	H message.HandlerMiddleware
}

// Compile-time check that MiddlewareRegisterer satisfies Registerer.
var _ Registerer = MiddlewareRegisterer{}

// Register adds mr.H to router via AddMiddleware.
func (mr MiddlewareRegisterer) Register(router *message.Router) {
	mw := mr.H
	router.AddMiddleware(mw)
}
// NewWaitForHandlerMiddleware returns a MiddlewareRegisterer and a channel.
// The channel is closed the first time the handler named handlerName returns
// with no produced messages and a nil error. Messages for other handlers pass
// straight through.
//
// The channel is closed at most once. The named handler may therefore keep
// processing further messages, possibly concurrently, after the signal
// without triggering a "close of closed channel" panic.
func NewWaitForHandlerMiddleware(handlerName string) (MiddlewareRegisterer, <-chan struct{}) {
	ch := make(chan struct{})
	// Shared by every handler the middleware wraps, so the close happens once.
	var closeOnce sync.Once
	return MiddlewareRegisterer{
		H: func(h message.HandlerFunc) message.HandlerFunc {
			return func(msg *message.Message) ([]*message.Message, error) {
				if message.HandlerNameFromCtx(msg.Context()) != handlerName {
					return h(msg)
				}
				result, err := h(msg)
				// Only a successful call that emitted nothing counts as the
				// signal; invocations that produce messages never close ch.
				if len(result) == 0 && err == nil {
					closeOnce.Do(func() { close(ch) })
				}
				return result, err
			}
		},
	}, ch
}
// NewPubSub returns an in-memory gochannel pub/sub configured with
// Persistent: true and a no-op logger. It is closed automatically during test
// cleanup.
func NewPubSub(tb TestingTB) *gochannel.GoChannel {
	tb.Helper()
	cfg := gochannel.Config{Persistent: true}
	ps := gochannel.NewGoChannel(cfg, watermill.NopLogger{})
	// The Close error is deliberately ignored: nothing useful can be done
	// with it at teardown.
	tb.Cleanup(func() { _ = ps.Close() })
	return ps
}