Skip to content

Commit

Permalink
GH-38795: [Go] Fix race GetToTimeFunc for Timestamp (#38797)
Browse files Browse the repository at this point in the history
### Rationale for this change
Adding RWMutex to protect `loc` in `TimestampType` and fix the race condition.

### Are these changes tested?
Yes, a unit test is added which is covered by the CI which runs with `-race`.

### Are there any user-facing changes?
Copying `TimestampType` will now be problematic and linters will show it as an error. In theory this shouldn't be a problem as most uses of TimestampType should be utilizing pointers to it rather than the value directly.

* Closes: #38795

Lead-authored-by: Matt Topol <[email protected]>
Co-authored-by: Benjamin Kietzman <[email protected]>
Co-authored-by: Ben Harkins <[email protected]>
Signed-off-by: Benjamin Kietzman <[email protected]>
  • Loading branch information
3 people authored Nov 27, 2023
1 parent b1f1ef4 commit 62e1e9a
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 28 deletions.
53 changes: 26 additions & 27 deletions go/arrow/compare_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,13 +116,13 @@ func TestTypeEqual(t *testing.T) {
fields: []Field{
{Name: "f1", Type: PrimitiveTypes.Uint16, Nullable: true},
},
index: map[string][]int{"f1": []int{0}},
index: map[string][]int{"f1": {0}},
},
&StructType{
fields: []Field{
{Name: "f1", Type: PrimitiveTypes.Uint32, Nullable: true},
},
index: map[string][]int{"f1": []int{0}},
index: map[string][]int{"f1": {0}},
},
false, true,
},
Expand All @@ -131,13 +131,13 @@ func TestTypeEqual(t *testing.T) {
fields: []Field{
{Name: "f1", Type: PrimitiveTypes.Uint32, Nullable: false},
},
index: map[string][]int{"f1": []int{0}},
index: map[string][]int{"f1": {0}},
},
&StructType{
fields: []Field{
{Name: "f1", Type: PrimitiveTypes.Uint32, Nullable: true},
},
index: map[string][]int{"f1": []int{0}},
index: map[string][]int{"f1": {0}},
},
false, false,
},
Expand All @@ -146,13 +146,13 @@ func TestTypeEqual(t *testing.T) {
fields: []Field{
{Name: "f0", Type: PrimitiveTypes.Uint32, Nullable: true},
},
index: map[string][]int{"f0": []int{0}},
index: map[string][]int{"f0": {0}},
},
&StructType{
fields: []Field{
{Name: "f1", Type: PrimitiveTypes.Uint32, Nullable: true},
},
index: map[string][]int{"f1": []int{0}},
index: map[string][]int{"f1": {0}},
},
false, false,
},
Expand All @@ -161,14 +161,14 @@ func TestTypeEqual(t *testing.T) {
fields: []Field{
{Name: "f1", Type: PrimitiveTypes.Uint32, Nullable: true},
},
index: map[string][]int{"f1": []int{0}},
index: map[string][]int{"f1": {0}},
},
&StructType{
fields: []Field{
{Name: "f1", Type: PrimitiveTypes.Uint32, Nullable: true},
{Name: "f2", Type: PrimitiveTypes.Uint32, Nullable: true},
},
index: map[string][]int{"f1": []int{0}, "f2": []int{1}},
index: map[string][]int{"f1": {0}, "f2": {1}},
},
false, true,
},
Expand All @@ -177,14 +177,14 @@ func TestTypeEqual(t *testing.T) {
fields: []Field{
{Name: "f1", Type: PrimitiveTypes.Uint32, Nullable: true},
},
index: map[string][]int{"f1": []int{0}},
index: map[string][]int{"f1": {0}},
},
&StructType{
fields: []Field{
{Name: "f1", Type: PrimitiveTypes.Uint32, Nullable: true},
{Name: "f2", Type: PrimitiveTypes.Uint32, Nullable: true},
},
index: map[string][]int{"f1": []int{0}, "f2": []int{1}},
index: map[string][]int{"f1": {0}, "f2": {1}},
},
false, false,
},
Expand All @@ -193,13 +193,13 @@ func TestTypeEqual(t *testing.T) {
fields: []Field{
{Name: "f1", Type: PrimitiveTypes.Uint32, Nullable: true},
},
index: map[string][]int{"f1": []int{0}},
index: map[string][]int{"f1": {0}},
},
&StructType{
fields: []Field{
{Name: "f2", Type: PrimitiveTypes.Uint32, Nullable: true},
},
index: map[string][]int{"f2": []int{0}},
index: map[string][]int{"f2": {0}},
},
false, false,
},
Expand All @@ -209,14 +209,14 @@ func TestTypeEqual(t *testing.T) {
{Name: "f1", Type: PrimitiveTypes.Uint16, Nullable: true},
{Name: "f2", Type: PrimitiveTypes.Float32, Nullable: false},
},
index: map[string][]int{"f1": []int{0}, "f2": []int{1}},
index: map[string][]int{"f1": {0}, "f2": {1}},
},
&StructType{
fields: []Field{
{Name: "f1", Type: PrimitiveTypes.Uint16, Nullable: true},
{Name: "f2", Type: PrimitiveTypes.Float32, Nullable: false},
},
index: map[string][]int{"f1": []int{0}, "f2": []int{1}},
index: map[string][]int{"f1": {0}, "f2": {1}},
},
true, false,
},
Expand All @@ -226,14 +226,14 @@ func TestTypeEqual(t *testing.T) {
{Name: "f1", Type: PrimitiveTypes.Uint16, Nullable: true},
{Name: "f2", Type: PrimitiveTypes.Float32, Nullable: false},
},
index: map[string][]int{"f1": []int{0}, "f2": []int{1}},
index: map[string][]int{"f1": {0}, "f2": {1}},
},
&StructType{
fields: []Field{
{Name: "f1", Type: PrimitiveTypes.Uint16, Nullable: true},
{Name: "f2", Type: PrimitiveTypes.Float32, Nullable: false},
},
index: map[string][]int{"f1": []int{0}, "f2": []int{1}},
index: map[string][]int{"f1": {0}, "f2": {1}},
},
true, false,
},
Expand All @@ -243,15 +243,15 @@ func TestTypeEqual(t *testing.T) {
{Name: "f1", Type: PrimitiveTypes.Uint16, Nullable: true},
{Name: "f2", Type: PrimitiveTypes.Float32, Nullable: false},
},
index: map[string][]int{"f1": []int{0}, "f2": []int{1}},
index: map[string][]int{"f1": {0}, "f2": {1}},
meta: MetadataFrom(map[string]string{"k1": "v1", "k2": "v2"}),
},
&StructType{
fields: []Field{
{Name: "f1", Type: PrimitiveTypes.Uint16, Nullable: true},
{Name: "f2", Type: PrimitiveTypes.Float32, Nullable: false},
},
index: map[string][]int{"f1": []int{0}, "f2": []int{1}},
index: map[string][]int{"f1": {0}, "f2": {1}},
meta: MetadataFrom(map[string]string{"k2": "v2", "k1": "v1"}),
},
true, true,
Expand All @@ -261,14 +261,14 @@ func TestTypeEqual(t *testing.T) {
fields: []Field{
{Name: "f1", Type: PrimitiveTypes.Uint32, Nullable: true},
},
index: map[string][]int{"f1": []int{0}},
index: map[string][]int{"f1": {0}},
meta: MetadataFrom(map[string]string{"k1": "v1"}),
},
&StructType{
fields: []Field{
{Name: "f1", Type: PrimitiveTypes.Uint32, Nullable: true},
},
index: map[string][]int{"f1": []int{0}},
index: map[string][]int{"f1": {0}},
meta: MetadataFrom(map[string]string{"k1": "v2"}),
},
true, false,
Expand All @@ -279,14 +279,14 @@ func TestTypeEqual(t *testing.T) {
{Name: "f1", Type: PrimitiveTypes.Uint16, Nullable: true, Metadata: MetadataFrom(map[string]string{"k1": "v1"})},
{Name: "f2", Type: PrimitiveTypes.Float32, Nullable: false},
},
index: map[string][]int{"f1": []int{0}, "f2": []int{1}},
index: map[string][]int{"f1": {0}, "f2": {1}},
},
&StructType{
fields: []Field{
{Name: "f1", Type: PrimitiveTypes.Uint16, Nullable: true, Metadata: MetadataFrom(map[string]string{"k1": "v2"})},
{Name: "f2", Type: PrimitiveTypes.Float32, Nullable: false},
},
index: map[string][]int{"f1": []int{0}, "f2": []int{1}},
index: map[string][]int{"f1": {0}, "f2": {1}},
},
false, true,
},
Expand All @@ -296,14 +296,14 @@ func TestTypeEqual(t *testing.T) {
{Name: "f1", Type: PrimitiveTypes.Uint16, Nullable: true},
{Name: "f1", Type: PrimitiveTypes.Uint32, Nullable: true},
},
index: map[string][]int{"f1": []int{0, 1}},
index: map[string][]int{"f1": {0, 1}},
},
&StructType{
fields: []Field{
{Name: "f1", Type: PrimitiveTypes.Uint16, Nullable: true},
{Name: "f1", Type: PrimitiveTypes.Uint32, Nullable: true},
},
index: map[string][]int{"f1": []int{0, 1}},
index: map[string][]int{"f1": {0, 1}},
},
true, true,
},
Expand All @@ -313,14 +313,14 @@ func TestTypeEqual(t *testing.T) {
{Name: "f1", Type: PrimitiveTypes.Uint32, Nullable: true},
{Name: "f1", Type: PrimitiveTypes.Uint16, Nullable: true},
},
index: map[string][]int{"f1": []int{0, 1}},
index: map[string][]int{"f1": {0, 1}},
},
&StructType{
fields: []Field{
{Name: "f1", Type: PrimitiveTypes.Uint16, Nullable: true},
{Name: "f1", Type: PrimitiveTypes.Uint32, Nullable: true},
},
index: map[string][]int{"f1": []int{0, 1}},
index: map[string][]int{"f1": {0, 1}},
},
false, true,
},
Expand All @@ -343,7 +343,6 @@ func TestTypeEqual(t *testing.T) {
MapOf(BinaryTypes.String, &TimestampType{
Unit: 0,
TimeZone: "UTC",
loc: nil,
}),
true, false,
},
Expand Down
16 changes: 15 additions & 1 deletion go/arrow/datatype_fixedwidth.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package arrow
import (
"fmt"
"strconv"
"sync"
"time"

"github.com/apache/arrow/go/v15/internal/json"
Expand Down Expand Up @@ -354,6 +355,7 @@ type TimestampType struct {
TimeZone string

loc *time.Location
mx sync.RWMutex
}

func (*TimestampType) ID() Type { return TIMESTAMP }
Expand Down Expand Up @@ -386,6 +388,8 @@ func (t *TimestampType) TimeUnit() TimeUnit { return t.Unit }
// This should be called if you change the value of the TimeZone after having
// potentially called GetZone.
func (t *TimestampType) ClearCachedLocation() {
t.mx.Lock()
defer t.mx.Unlock()
t.loc = nil
}

Expand All @@ -398,10 +402,20 @@ func (t *TimestampType) ClearCachedLocation() {
// so if you change the value of TimeZone after calling this, make sure to call
// ClearCachedLocation.
func (t *TimestampType) GetZone() (*time.Location, error) {
t.mx.RLock()
if t.loc != nil {
defer t.mx.RUnlock()
return t.loc, nil
}

t.mx.RUnlock()
t.mx.Lock()
defer t.mx.Unlock()
// in case GetZone() was called in between releasing the read lock and
// getting the write lock
if t.loc != nil {
return t.loc, nil
}
// the TimeZone string is allowed to be either a valid tzdata string
// such as "America/New_York" or an absolute offset of the form -XX:XX
// or +XX:XX
Expand All @@ -415,7 +429,7 @@ func (t *TimestampType) GetZone() (*time.Location, error) {

if loc, err := time.LoadLocation(t.TimeZone); err == nil {
t.loc = loc
return t.loc, err
return loc, err
}

// at this point we know that the timezone isn't empty, and didn't match
Expand Down
25 changes: 25 additions & 0 deletions go/arrow/datatype_fixedwidth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package arrow_test

import (
"sync"
"testing"
"time"

Expand Down Expand Up @@ -180,6 +181,30 @@ func TestTimestampType_GetToTimeFunc(t *testing.T) {
assert.Equal(t, "2345-12-29T19:00:00-05:00", toTimeNY(ts).Format(time.RFC3339))
}

// Test race condition from GH-38795
func TestGetToTimeFuncRace(t *testing.T) {
var (
wg sync.WaitGroup
w = make(chan bool)
routineNum = 10
)

wg.Add(routineNum)
for i := 0; i < routineNum; i++ {
go func() {
defer wg.Done()

<-w

_, _ = arrow.FixedWidthTypes.Timestamp_s.(*arrow.TimestampType).GetToTimeFunc()
}()
}

close(w)

wg.Wait()
}

func TestTime32Type(t *testing.T) {
for _, tc := range []struct {
unit arrow.TimeUnit
Expand Down

0 comments on commit 62e1e9a

Please sign in to comment.