Refactor: GoroutineTracker with unnecessary usage of reflect

Today, I encountered this code in my company's codebase (the code and comments are rewritten for demo purposes and do not include any proprietary code):

type GoroutineTracker struct {
	wg sync.WaitGroup
	// ... some other fields
}
// Go starts a new goroutine and tracks it with some metrics.
func (g *GoroutineTracker) Go(ctx context.Context, name string, f any, args ...any) {
	fn := reflect.TypeOf(f)
	if fn.Kind() != reflect.Func { panic("must be function") }
	if fn.NumIn() != len(args) { panic("args does not match fn signature") }
	if fn.NumOut() > 0 { panic("output from fn is ignored") }
	
	g.wg.Add(1)
	id := g.startCaptureTime()
	go func() {
		defer func() {
			r := recover()
			// ... some panic handling code
			g.wg.Done()
			g.endCaptureTime(id)
		}()
		
		input := typez.MapFunc(args, func(arg any) reflect.Value {
			return reflect.ValueOf(arg)
		})
		_ = reflect.ValueOf(f).Call(input)
	}()
}
// Wait waits for all goroutines to finish.
func (g *GoroutineTracker) Wait() { g.wg.Wait() }

The GoroutineTracker tracks goroutine usage in the codebase, for example the number of goroutines and the time taken by each one. The Go method starts a new goroutine and tracks it. The Wait method waits for all tracked goroutines to finish.

Example of usage:

g := NewGoroutineTracker()
g.Go(ctx, "task1", doTask1, arg1, arg2)
g.Go(ctx, "task2", doTask2, arg3)
g.Wait()

Problem: The usage of reflect is unnecessary and can be avoided

Well, that code works, but it uses the reflect package to check the function signature at runtime and then call the function. That is unnecessary: we can avoid it by wrapping each call in a closure and changing the usage to:

g := NewGoroutineTracker()
g.Go(ctx, func() error {
	return doTask1(arg1, arg2)
})
g.Go(ctx, func() error {
	return doTask2(arg3)
})

The new code is simpler and has several benefits:

  • Type safety: There is no need to check the function signature with reflect; the compiler does it for us. The original code panics at runtime if the arguments do not match the function signature.
  • Error handling: We can return an error from the function and handle it in the caller. The original code ignores the output of the function.
  • Readability: The new code is more readable and easier to understand. We can see the function signature and arguments directly in the code.
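
To make the type-safety point concrete, here is a minimal sketch. callWithReflect and doTask1 are hypothetical stand-ins, not the original code: the helper mimics the reflection-based dispatch and turns the panic raised by reflect into an error so it can be printed.

```go
package main

import (
	"fmt"
	"reflect"
)

// callWithReflect mimics the reflection-based dispatch of the original Go
// method (hypothetical helper, for illustration only). A panic raised by
// reflect is converted into an error so the failure can be inspected.
func callWithReflect(f any, args ...any) (err error) {
	defer func() {
		if r := recover(); r != nil {
			err = fmt.Errorf("runtime failure: %v", r)
		}
	}()
	in := make([]reflect.Value, len(args))
	for i, arg := range args {
		in[i] = reflect.ValueOf(arg)
	}
	reflect.ValueOf(f).Call(in)
	return nil
}

// doTask1 is a stand-in for a real task.
func doTask1(id int, name string) error { return nil }

func main() {
	// Right number of arguments, wrong types: this compiles and only
	// fails when the function is actually called.
	fmt.Println(callWithReflect(doTask1, "42", 7))

	// With a closure the compiler checks the call. Swapping the two
	// arguments here would be a build error, not a runtime panic.
	fn := func() error { return doTask1(42, "job") }
	fmt.Println(fn())
}
```

Note that a check on the number of arguments alone, as in the original Go method, does not catch the first call: the count matches and only the types are wrong.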

A better implementation of GoroutineTracker

Here is the refactored code:

func (g *GoroutineTracker) Go(ctx context.Context, fn func() error) {
	g.wg.Add(1)
	id := g.startCaptureTime()
	go func() (err error) {
		defer func() {
			r := recover()
			// capture returned error and panic
			g.endCaptureTime(id, r, err)
			g.wg.Done()
		}()
		// just call the function, no reflect needed
		return fn()
	}()
}

Wait for all goroutines to finish before shutting down

Another use case for GoroutineTracker is waiting for all goroutines to finish before shutting down the application. So we can have two types of waiting:

  • In a function: Waiting for all local goroutines to finish.
  • At application shutdown: Waiting for all goroutines started by any GoroutineTracker to finish.

We can implement this by adding a global tracker and making every local tracker register its goroutines with the global tracker:

type GlobalTracker struct {
	wg sync.WaitGroup
	// ... some other fields
}
type GoroutineTracker struct {
	parent *GlobalTracker
	wg sync.WaitGroup
	// ... some other fields
}
func (g *GlobalTracker) New() *GoroutineTracker {
	return &GoroutineTracker{parent: g}
}
func (g *GoroutineTracker) Go(ctx context.Context, fn func() error) {
	g.wg.Add(1)            // use both parent and local wg
	g.parent.wg.Add(1)     //   to track the new goroutine
	id := g.startCaptureTime()
	go func() (err error) {
		defer func() {
			r := recover() // capture the panic, if any
			g.endCaptureTime(id, r, err)
			g.wg.Done()
			g.parent.wg.Done()
		}()
		
		return fn()
	}()
}
func (g *GlobalTracker) WaitForAll() { g.wg.Wait() }
func (g *GoroutineTracker) Wait()    { g.wg.Wait() }

And we can use WaitForAll() to wait for all goroutines to finish before shutting down the application:

type FooService struct {
	tracker *GlobalTracker
	// ... some other fields
}
func (s *FooService) DoSomething(ctx context.Context) {
	g := s.tracker.New()
	g.Go(ctx, func() error { return s.doTask1(arg1, arg2) })
	g.Go(ctx, func() error { return s.doTask2(arg3) })
	g.Wait()     // wait for local goroutines, this is optional
}

func main() {
	// some initialization, then start the application
	globalTracker := &GlobalTracker{}
	fooService := FooService{tracker: globalTracker, /*...*/}
	application.Start()
	
	// wait for all goroutines to finish before shutting down
	<-application.Done()
	globalTracker.WaitForAll()
}

Conclusion

In conclusion, the original implementation of GoroutineTracker works and can track goroutines, but its use of the reflect package to check and call functions dynamically introduces unnecessary complexity and potential runtime errors. By refactoring the code to accept a func() error directly, we get better type safety, simpler error handling, and improved readability. The compiler now checks that function signatures and arguments match, which makes the code more robust and easier to maintain.

Author

I'm Oliver Nguyen. A software maker working mostly in Go and JavaScript. I enjoy learning and seeing a better version of myself each day. Occasionally spin off new open source projects. Share knowledge and thoughts during my journey.
