diff --git a/README.md b/README.md index 127b6b2..e263ef8 100644 --- a/README.md +++ b/README.md @@ -33,12 +33,39 @@ type Client func (c *Client) InsertBatch(ctx context.Context, b Bucket, items []BatchInsertItem) error func (c *Client) InsertItem(ctx context.Context, b Bucket, pk string, sk string, ct CausalityToken, item []byte) error func (c *Client) PollItem(ctx context.Context, b Bucket, pk string, sk string, ct CausalityToken, timeout time.Duration) (Item, CausalityToken, error) + func (c *Client) PollRange(ctx context.Context, b Bucket, pk string, q PollRangeQuery, timeout time.Duration) (*PollRangeResponse, error) func (c *Client) ReadBatch(ctx context.Context, b Bucket, q []ReadBatchSearch) ([]BatchSearchResult, error) func (c *Client) ReadIndex(ctx context.Context, b Bucket, q ReadIndexQuery) (*ReadIndexResponse, error) func (c *Client) ReadItemMulti(ctx context.Context, b Bucket, pk string, sk string) ([]Item, CausalityToken, error) func (c *Client) ReadItemSingle(ctx context.Context, b Bucket, pk string, sk string) (Item, CausalityToken, error) ``` +## Scrolling (Client-side / Go API) +To handle iteration in the K2V API, helper functions for simple cases are provided. + +For example, to perform a bulk search: +```go +handleBatch := func(result *k2v.BatchSearchResult) error { + log.Println(result.Items) + return nil +} +err := k2v.ScrollBatchSearch(ctx, f.cli, f.bucket, []k2v.BatchSearch{ + { + PartitionKey: "pk1", + }, + { + PartitionKey: "pk2", + Limit: 1, + }, +}, handleBatch) +``` + +This will repeatedly make calls to **ReadBatch** (batch search), using `nextStart` from the responses to generate subsequent requests until all queries are exhausted. + +See `ScrollIndex(ctx context.Context, client IndexScroller, b Bucket, query ReadIndexQuery, fn ReadIndexResponseHandler) error` for the equivalent for batch index reads. + +No helper is available for `PollRange()` yet. + ## Integration Tests ```shell K2V_ENDPOINT="http://[::1]:3904" \ diff --git a/client.go b/client.go index 8b63909..74f24ad 100644 --- a/client.go +++ b/client.go @@ -19,9 +19,13 @@ import ( const CausalityTokenHeader = "X-Garage-Causality-Token" +const payloadHashEmpty = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" +const payloadHashUnsigned = "UNSIGNED-PAYLOAD" + var TombstoneItemErr = errors.New("item is a tombstone") var NoSuchItemErr = errors.New("item does not exist") var ConcurrentItemsErr = errors.New("item has multiple concurrent values") +var NotModifiedTimeoutErr = errors.New("not modified within timeout") var awsSigner = v4.NewSigner() @@ -94,7 +98,7 @@ func (c *Client) executeRequest(req *http.Request) (*http.Response, error) { return nil, err } - resp, err := http.DefaultClient.Do(req) + resp, err := c.httpClient.Do(req) if err != nil { return nil, err } @@ -103,14 +107,25 @@ func (c *Client) executeRequest(req *http.Request) (*http.Response, error) { } func (c *Client) signRequest(req *http.Request) error { + if c.key.ID == "" || c.key.Secret == "" { + return errors.New("no credentials provided") + } + creds := aws.Credentials{ AccessKeyID: c.key.ID, SecretAccessKey: c.key.Secret, } - const noBody = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" - req.Header.Set("X-Amz-Content-Sha256", noBody) - err := awsSigner.SignHTTP(req.Context(), creds, req, noBody, "k2v", "garage", time.Now()) + var payloadHash string + if req.Body == nil || req.Body == http.NoBody { + + payloadHash = payloadHashEmpty + } else { + payloadHash = payloadHashUnsigned + } + req.Header.Set("X-Amz-Content-Sha256", payloadHash) + + err := awsSigner.SignHTTP(req.Context(), creds, req, payloadHash, "k2v", "garage", time.Now()) if err != nil { return err } @@ -143,14 +158,14 @@ type ReadIndexResponsePartitionKey struct { } type ReadIndexResponse struct { - Prefix any `json:"prefix"` - Start any `json:"start"` - End any `json:"end"` - Limit any `json:"limit"` + Prefix *string `json:"prefix"` + Start *string `json:"start"` + End *string `json:"end"` + Limit *int `json:"limit"` Reverse bool `json:"reverse"` PartitionKeys []ReadIndexResponsePartitionKey `json:"partitionKeys"` More bool `json:"more"` - NextStart any `json:"nextStart"` + NextStart *string `json:"nextStart"` } func (c *Client) ReadIndex(ctx context.Context, b Bucket, q ReadIndexQuery) (*ReadIndexResponse, error) { @@ -247,9 +262,6 @@ func (c *Client) ReadItemMulti(ctx context.Context, b Bucket, pk string, sk stri return []Item{body}, ct, nil case "application/json": var items []Item - if err != nil { - return nil, "", err - } if err := json.Unmarshal(body, &items); err != nil { return nil, "", err } @@ -300,6 +312,8 @@ func (c *Client) readItemSingle(ctx context.Context, b Bucket, pk string, sk str return nil, "", NoSuchItemErr case http.StatusConflict: return nil, ct, ConcurrentItemsErr + case http.StatusNotModified: + return nil, "", NotModifiedTimeoutErr default: return nil, "", fmt.Errorf("http status code %d", resp.StatusCode) } diff --git a/client_test.go b/client_test.go index 512da64..d7feaf1 100644 --- a/client_test.go +++ b/client_test.go @@ -1,16 +1,12 @@ package k2v_test import ( + k2v "code.notaphish.fyi/milas/garage-k2v-go" "context" + "github.com/stretchr/testify/require" "math/rand/v2" - "net/http/httptrace" "strconv" "testing" - "time" - - k2v "code.notaphish.fyi/milas/garage-k2v-go" - "github.com/stretchr/testify/require" - "go.uber.org/goleak" ) type fixture struct { @@ -22,11 +18,6 @@ type fixture struct { func newFixture(t testing.TB) (*fixture, context.Context) { t.Helper() - - t.Cleanup(func() { - goleak.VerifyNone(t) - }) - ctx := testContext(t) cli := k2v.NewClient(k2v.EndpointFromEnv(), k2v.KeyFromEnv()) @@ -36,7 +27,7 @@ func newFixture(t testing.TB) (*fixture, context.Context) { t: t, ctx: ctx, cli: cli, - bucket: k2v.Bucket("k2v-test"), + bucket: TestBucket, } return f, ctx @@ -48,23 +39,34 @@ func testContext(t testing.TB) context.Context { return ctx } -func randomKey() string { - return "key-" + strconv.Itoa(rand.IntN(1000000)) +func randomKey(prefix string) string { + return prefix + "-" + strconv.Itoa(rand.IntN(1000000)) +} + +func randomPk() string { + return randomKey("pk") +} + +func randomSk() string { + return randomKey("sk") } func TestClient_InsertItem(t *testing.T) { + t.Parallel() f, ctx := newFixture(t) - err := f.cli.InsertItem(ctx, f.bucket, randomKey(), randomKey(), "", []byte("hello")) + err := f.cli.InsertItem(ctx, f.bucket, randomPk(), randomSk(), "", []byte("hello")) require.NoError(t, err) } func TestClient_ReadItemNotExist(t *testing.T) { + t.Parallel() f, ctx := newFixture(t) - pk := randomKey() - sk := randomKey() + pk := randomPk() + sk := randomSk() t.Run("Single", func(t *testing.T) { + t.Parallel() item, ct, err := f.cli.ReadItemSingle(ctx, f.bucket, pk, sk) require.ErrorIs(t, err, k2v.NoSuchItemErr) require.Nil(t, item) @@ -72,6 +74,7 @@ func TestClient_ReadItemNotExist(t *testing.T) { }) t.Run("Multi", func(t *testing.T) { + t.Parallel() items, ct, err := f.cli.ReadItemMulti(ctx, f.bucket, pk, sk) require.ErrorIs(t, err, k2v.NoSuchItemErr) require.Empty(t, items) @@ -80,10 +83,11 @@ func TestClient_ReadItemNotExist(t *testing.T) { } func TestClient_ReadItemTombstone(t *testing.T) { + t.Parallel() f, ctx := newFixture(t) - pk := randomKey() - sk := randomKey() + pk := randomPk() + sk := randomSk() t.Logf("Creating item: PK=%s, SK=%s", pk, sk) @@ -110,15 +114,17 @@ func TestClient_ReadItemTombstone(t *testing.T) { } func TestClient_ReadItemSingleRevision(t *testing.T) { + t.Parallel() f, ctx := newFixture(t) - pk := randomKey() - sk := randomKey() + pk := randomPk() + sk := randomSk() err := f.cli.InsertItem(ctx, f.bucket, pk, sk, "", []byte("hello")) require.NoError(t, err) t.Run("Single", func(t *testing.T) { + t.Parallel() item, ct, err := f.cli.ReadItemSingle(ctx, f.bucket, pk, sk) require.NoError(t, err) require.Equal(t, "hello", string(item)) @@ -126,6 +132,7 @@ func TestClient_ReadItemSingleRevision(t *testing.T) { }) t.Run("Multi", func(t *testing.T) { + t.Parallel() items, ct, err := f.cli.ReadItemMulti(ctx, f.bucket, pk, sk) require.NoError(t, err) require.Len(t, items, 1) @@ -135,10 +142,11 @@ func TestClient_ReadItemSingleRevision(t *testing.T) { } func TestClient_ReadItemMultipleRevisions(t *testing.T) { + t.Parallel() f, ctx := newFixture(t) - pk := randomKey() - sk := randomKey() + pk := randomPk() + sk := randomSk() err := f.cli.InsertItem(ctx, f.bucket, pk, sk, "", []byte("hello1")) require.NoError(t, err) @@ -148,6 +156,7 @@ func TestClient_ReadItemMultipleRevisions(t *testing.T) { require.NoError(t, err) t.Run("Single", func(t *testing.T) { + t.Parallel() item, ct, err := f.cli.ReadItemSingle(ctx, f.bucket, pk, sk) require.ErrorIs(t, err, k2v.ConcurrentItemsErr) require.Nil(t, item) @@ -155,6 +164,7 @@ func TestClient_ReadItemMultipleRevisions(t *testing.T) { }) t.Run("Multi", func(t *testing.T) { + t.Parallel() items, ct, err := f.cli.ReadItemMulti(ctx, f.bucket, pk, sk) require.NoError(t, err) require.Len(t, items, 2) @@ -163,38 +173,3 @@ func TestClient_ReadItemMultipleRevisions(t *testing.T) { require.NotEmpty(t, ct) }) } - -func TestClient_PollItem(t *testing.T) { - f, ctx := newFixture(t) - - pk := randomKey() - sk := randomKey() - - err := f.cli.InsertItem(ctx, f.bucket, pk, sk, "", []byte("hello1")) - require.NoError(t, err) - - _, ct, err := f.cli.ReadItemSingle(ctx, f.bucket, pk, sk) - - pollReadyCh := make(chan struct{}) - go func(ct k2v.CausalityToken) { - select { - case <-ctx.Done(): - t.Errorf("Context canceled: %v", ctx.Err()) - return - case <-pollReadyCh: - t.Logf("PollItem connected") - } - err = f.cli.InsertItem(ctx, f.bucket, pk, sk, ct, []byte("hello2")) - require.NoError(t, err) - }(ct) - - trace := &httptrace.ClientTrace{ - WroteRequest: func(_ httptrace.WroteRequestInfo) { - pollReadyCh <- struct{}{} - }, - } - item, ct, err := f.cli.PollItem(httptrace.WithClientTrace(ctx, trace), f.bucket, pk, sk, ct, 5*time.Second) - require.NoError(t, err) - require.Equal(t, "hello2", string(item)) - require.NotEmpty(t, ct) -} diff --git a/go.mod b/go.mod index 9c37efd..81b98dc 100644 --- a/go.mod +++ b/go.mod @@ -3,14 +3,14 @@ module code.notaphish.fyi/milas/garage-k2v-go go 1.23.1 require ( - github.com/aws/aws-sdk-go-v2 v1.32.2 - github.com/davecgh/go-spew v1.1.1 + github.com/aws/aws-sdk-go-v2 v1.36.3 github.com/stretchr/testify v1.8.0 go.uber.org/goleak v1.3.0 ) require ( - github.com/aws/smithy-go v1.22.0 // indirect + github.com/aws/smithy-go v1.22.3 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/kr/text v0.2.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect diff --git a/go.sum b/go.sum index 052ea4d..a6a637d 100644 --- a/go.sum +++ b/go.sum @@ -1,7 +1,7 @@ -github.com/aws/aws-sdk-go-v2 v1.32.2 h1:AkNLZEyYMLnx/Q/mSKkcMqwNFXMAvFto9bNsHqcTduI= -github.com/aws/aws-sdk-go-v2 v1.32.2/go.mod h1:2SK5n0a2karNTv5tbP1SjsX0uhttou00v/HpXKM1ZUo= -github.com/aws/smithy-go v1.22.0 h1:uunKnWlcoL3zO7q+gG2Pk53joueEOsnNB28QdMsmiMM= -github.com/aws/smithy-go v1.22.0/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg= +github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM= +github.com/aws/aws-sdk-go-v2 v1.36.3/go.mod h1:LLXuLpgzEbD766Z5ECcRmi8AzSwfZItDtmABVkRLGzg= +github.com/aws/smithy-go v1.22.3 h1:Z//5NuZCSW6R4PhQ93hShNbyBbn8BWCmCVCt+Q8Io5k= +github.com/aws/smithy-go v1.22.3/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI= github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= diff --git a/main_test.go b/main_test.go new file mode 100644 index 0000000..65554a0 --- /dev/null +++ b/main_test.go @@ -0,0 +1,16 @@ +package k2v_test + +import ( + k2v "code.notaphish.fyi/milas/garage-k2v-go" + "go.uber.org/goleak" + "os" + "testing" +) + +const BucketEnvVar = "K2V_TEST_BUCKET" + +var TestBucket = k2v.Bucket(os.Getenv(BucketEnvVar)) + +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} diff --git a/pager.go b/pager.go new file mode 100644 index 0000000..645e4e7 --- /dev/null +++ b/pager.go @@ -0,0 +1,77 @@ +package k2v + +import ( + "context" + "errors" +) + +var StopScroll = errors.New("scroll canceled") + +// ReadIndexResponseHandler is invoked for each batch of index read results. +// +// If an error is returned, scrolling is halted and the error is propagated. +// The sentinel value StopScroll can be returned to end iteration early without propagating an error. +type ReadIndexResponseHandler func(resp *ReadIndexResponse) error + +type BatchSearchResultHandler func(result *BatchSearchResult) error + +type IndexScroller interface { + ReadIndex(ctx context.Context, b Bucket, q ReadIndexQuery) (*ReadIndexResponse, error) +} + +var _ IndexScroller = &Client{} + +type BatchSearchScroller interface { + ReadBatch(ctx context.Context, b Bucket, q []BatchSearch) ([]*BatchSearchResult, error) +} + +var _ BatchSearchScroller = &Client{} + +// ScrollIndex calls the ReadIndex API serially, invoking the provided function for each response (batch) until there are no more results. +func ScrollIndex(ctx context.Context, client IndexScroller, b Bucket, q ReadIndexQuery, fn ReadIndexResponseHandler) error { + for { + resp, err := client.ReadIndex(ctx, b, q) + if err != nil { + return err + } + if err := fn(resp); err != nil { + if errors.Is(err, StopScroll) { + return nil + } + return err + } + if !resp.More || resp.NextStart == nil { + break + } + q.Start = *resp.NextStart + } + return nil +} + +func ScrollBatchSearch(ctx context.Context, client BatchSearchScroller, b Bucket, q []BatchSearch, fn BatchSearchResultHandler) error { + for { + results, err := client.ReadBatch(ctx, b, q) + if err != nil { + return err + } + var nextQ []BatchSearch + for i := range results { + if results[i].More && results[i].NextStart != nil { + batch := q[i] + batch.Start = *results[i].NextStart + nextQ = append(nextQ, batch) + } + if err := fn(results[i]); err != nil { + if errors.Is(err, StopScroll) { + return nil + } + return err + } + } + if len(nextQ) == 0 { + break + } + q = nextQ + } + return nil +} diff --git a/pager_test.go b/pager_test.go new file mode 100644 index 0000000..6ed8cd6 --- /dev/null +++ b/pager_test.go @@ -0,0 +1,83 @@ +package k2v_test + +import ( + k2v "code.notaphish.fyi/milas/garage-k2v-go" + "context" + "fmt" + "github.com/stretchr/testify/require" + "strconv" + "strings" + "testing" +) + +func TestScrollIndex(t *testing.T) { + t.Parallel() + f, ctx := newFixture(t) + + pkPrefix := randomPk() + for i := range 5 { + require.NoError(t, f.cli.InsertItem(ctx, f.bucket, pkPrefix+"-"+strconv.Itoa(i), randomSk(), "", []byte("hello"+strconv.Itoa(i)))) + } + + var responses []*k2v.ReadIndexResponse + err := k2v.ScrollIndex(ctx, f.cli, f.bucket, k2v.ReadIndexQuery{Prefix: pkPrefix, Limit: 1}, func(resp *k2v.ReadIndexResponse) error { + responses = append(responses, resp) + return nil + }) + require.NoError(t, err) + require.Len(t, responses, 5) +} + +func ExampleScrollIndex() { + ctx := context.Background() + client := k2v.NewClient(k2v.EndpointFromEnv(), k2v.KeyFromEnv()) + defer client.Close() + + pkPrefix := randomPk() + for i := range 5 { + _ = client.InsertItem(ctx, TestBucket, pkPrefix+"-"+strconv.Itoa(i), randomSk(), "", []byte("hello")) + } + + var responses []*k2v.ReadIndexResponse + _ = k2v.ScrollIndex(ctx, client, TestBucket, k2v.ReadIndexQuery{Prefix: pkPrefix, Limit: 25}, func(resp *k2v.ReadIndexResponse) error { + responses = append(responses, resp) + return nil + }) + fmt.Println(len(responses[0].PartitionKeys)) + // Output: + // 5 +} + +func TestScrollItems(t *testing.T) { + t.Parallel() + f, ctx := newFixture(t) + + pk1 := randomKey("pk1") + sk1 := randomKey("sk1") + require.NoError(t, f.cli.InsertItem(ctx, f.bucket, pk1, sk1, "", []byte(strings.Join([]string{"hello", pk1, sk1}, "-")))) + + pk2 := randomKey("pk2") + for i := range 5 { + skN := randomKey(fmt.Sprintf("sk%d", i+2)) + require.NoError(t, f.cli.InsertItem(ctx, f.bucket, pk2, skN, "", []byte(strings.Join([]string{"hello", pk2, skN, strconv.Itoa(i)}, "-")))) + } + + q := []k2v.BatchSearch{ + { + PartitionKey: pk1, + }, + { + PartitionKey: pk2, + Limit: 1, + }, + } + + var results []*k2v.BatchSearchResult + err := k2v.ScrollBatchSearch(ctx, f.cli, f.bucket, q, func(result *k2v.BatchSearchResult) error { + results = append(results, result) + return nil + }) + require.NoError(t, err) + require.NotEmpty(t, results) + require.Len(t, results, 6) +} diff --git a/poll_batch.go b/poll_batch.go new file mode 100644 index 0000000..83df65b --- /dev/null +++ b/poll_batch.go @@ -0,0 +1,84 @@ +package k2v + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "net/url" + "time" +) + +type PollRangeQuery struct { + // Prefix restricts items to poll to those whose sort keys start with this prefix. + Prefix string `json:"prefix,omitempty"` + + // Start is the sort key of the first item to poll. + Start string `json:"start,omitempty"` + + // End is the sort key of the last item to poll (excluded). + End string `json:"end,omitempty"` + + // SeenMarker is an opaque string returned by a previous PollRange call, that represents items already seen. + SeenMarker string `json:"seenMarker,omitempty"` +} + +type PollRangeResponse struct { + SeenMarker string `json:"seenMarker"` + Items []SearchResultItem `json:"items"` +} + +func (c *Client) PollRange(ctx context.Context, b Bucket, pk string, q PollRangeQuery, timeout time.Duration) (*PollRangeResponse, error) { + u, err := url.Parse(c.endpoint) + if err != nil { + return nil, err + } + u.Path = string(b) + "/" + pk + query := make(url.Values) + query.Set("poll_range", "") + u.RawQuery = query.Encode() + + reqBody, err := json.Marshal(struct { + PollRangeQuery + Timeout int `json:"timeout,omitempty"` + }{ + PollRangeQuery: q, + Timeout: int(timeout.Seconds()), + }) + if err != nil { + return nil, err + } + + req, err := http.NewRequestWithContext(ctx, "SEARCH", u.String(), bytes.NewReader(reqBody)) + if err != nil { + return nil, err + } + + resp, err := c.executeRequest(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, err + } + + switch resp.StatusCode { + case http.StatusOK: + break + case http.StatusNotModified: + return nil, NotModifiedTimeoutErr + default: + return nil, fmt.Errorf("http status code %d: %s", resp.StatusCode, body) + } + + var result PollRangeResponse + if err := json.Unmarshal(body, &result); err != nil { + return nil, err + } + return &result, nil +} diff --git a/poll_batch_test.go b/poll_batch_test.go new file mode 100644 index 0000000..c0c91cd --- /dev/null +++ b/poll_batch_test.go @@ -0,0 +1,118 @@ +package k2v_test + +import ( + k2v "code.notaphish.fyi/milas/garage-k2v-go" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "net/http/httptrace" + "strconv" + "testing" + "time" +) + +func TestClient_PollRange(t *testing.T) { + t.Parallel() + + f, ctx := newFixture(t) + + pk := randomPk() + sk := randomSk() + + for i := range 5 { + err := f.cli.InsertItem(ctx, f.bucket, pk, sk+"-"+strconv.Itoa(i), "", []byte("hello1")) + require.NoError(t, err) + } + + // first read should complete immediately + q := k2v.PollRangeQuery{ + Start: sk, + } + result, err := f.cli.PollRange(ctx, f.bucket, pk, q, 5*time.Second) + require.NoError(t, err) + require.NotEmpty(t, result.SeenMarker) + require.Len(t, result.Items, 5) + for i := range result.Items { + require.Len(t, result.Items[i].Values, 1) + require.Equal(t, "hello1", string(result.Items[i].Values[0])) + } + + updateErrCh := make(chan error, 1) + pollReadyCh := make(chan struct{}) + go func(sk string, ct k2v.CausalityToken) { + defer close(updateErrCh) + select { + case <-ctx.Done(): + t.Errorf("Context canceled: %v", ctx.Err()) + return + case <-pollReadyCh: + t.Logf("PollRange connected") + } + updateErrCh <- f.cli.InsertItem(ctx, f.bucket, pk, sk, ct, []byte("hello2")) + }(result.Items[3].SortKey, k2v.CausalityToken(result.Items[3].CausalityToken)) + + trace := &httptrace.ClientTrace{ + WroteRequest: func(_ httptrace.WroteRequestInfo) { + pollReadyCh <- struct{}{} + }, + } + + q.SeenMarker = result.SeenMarker + result, err = f.cli.PollRange(httptrace.WithClientTrace(ctx, trace), f.bucket, pk, q, 5*time.Second) + if assert.NoError(t, err) { + require.NotEmpty(t, result.SeenMarker) + require.Len(t, result.Items, 1) + require.Len(t, result.Items[0].Values, 1) + require.Equal(t, sk+"-3", result.Items[0].SortKey) + require.Equal(t, "hello2", string(result.Items[0].Values[0])) + } + require.NoError(t, <-updateErrCh) + + require.NoError(t, err) + require.NoError(t, f.cli.InsertItem(ctx, f.bucket, pk, result.Items[0].SortKey, k2v.CausalityToken(result.Items[0].CausalityToken), []byte("hello3"))) + + q.SeenMarker = result.SeenMarker + result, err = f.cli.PollRange(ctx, f.bucket, pk, q, 5*time.Second) + if assert.NoError(t, err) { + require.NotEmpty(t, result.SeenMarker) + require.Len(t, result.Items, 1) + require.Len(t, result.Items[0].Values, 1) + require.Equal(t, sk+"-3", result.Items[0].SortKey) + require.Equal(t, "hello3", string(result.Items[0].Values[0])) + } +} + +func TestClient_PollRange_Timeout(t *testing.T) { + if testing.Short() { + t.Skip("Skipping in short mode: 1 sec minimum to trigger timeout") + return + } + t.Parallel() + + f, ctx := newFixture(t) + + pk := randomPk() + sk := randomSk() + + for i := range 5 { + err := f.cli.InsertItem(ctx, f.bucket, pk, sk+"-"+strconv.Itoa(i), "", []byte("hello1")) + require.NoError(t, err) + } + + // first read should complete immediately + q := k2v.PollRangeQuery{ + Start: sk, + } + result, err := f.cli.PollRange(ctx, f.bucket, pk, q, 5*time.Second) + require.NoError(t, err) + require.NotEmpty(t, result.SeenMarker) + require.Len(t, result.Items, 5) + for i := range result.Items { + require.Len(t, result.Items[i].Values, 1) + require.Equal(t, "hello1", string(result.Items[i].Values[0])) + } + + q.SeenMarker = result.SeenMarker + result, err = f.cli.PollRange(ctx, f.bucket, pk, q, 1*time.Second) + require.ErrorIs(t, err, k2v.NotModifiedTimeoutErr) + require.Nil(t, result) +} diff --git a/poll_single_test.go b/poll_single_test.go new file mode 100644 index 0000000..d65521e --- /dev/null +++ b/poll_single_test.go @@ -0,0 +1,72 @@ +package k2v_test + +import ( + k2v "code.notaphish.fyi/milas/garage-k2v-go" + "github.com/stretchr/testify/require" + "net/http/httptrace" + "testing" + "time" +) + +func TestClient_PollItem(t *testing.T) { + t.Parallel() + + f, ctx := newFixture(t) + + pk := randomPk() + sk := randomSk() + + err := f.cli.InsertItem(ctx, f.bucket, pk, sk, "", []byte("hello1")) + require.NoError(t, err) + + _, ct, err := f.cli.ReadItemSingle(ctx, f.bucket, pk, sk) + require.NoError(t, err) + + updateErrCh := make(chan error, 1) + pollReadyCh := make(chan struct{}) + go func(ct k2v.CausalityToken) { + defer close(updateErrCh) + select { + case <-ctx.Done(): + t.Errorf("Context canceled: %v", ctx.Err()) + return + case <-pollReadyCh: + t.Logf("PollItem connected") + } + updateErrCh <- f.cli.InsertItem(ctx, f.bucket, pk, sk, ct, []byte("hello2")) + }(ct) + + trace := &httptrace.ClientTrace{ + WroteRequest: func(_ httptrace.WroteRequestInfo) { + pollReadyCh <- struct{}{} + }, + } + item, ct, err := f.cli.PollItem(httptrace.WithClientTrace(ctx, trace), f.bucket, pk, sk, ct, 5*time.Second) + require.NoError(t, err) + require.Equal(t, "hello2", string(item)) + require.NotEmpty(t, ct) + require.NoError(t, <-updateErrCh) +} + +func TestClient_PollItem_Timeout(t *testing.T) { + if testing.Short() { + t.Skip("Skipping in short mode: 1 sec minimum to trigger timeout") + return + } + t.Parallel() + + f, ctx := newFixture(t) + + pk := randomPk() + sk := randomSk() + + err := f.cli.InsertItem(ctx, f.bucket, pk, sk, "", []byte("hello1")) + require.NoError(t, err) + + _, ct, err := f.cli.ReadItemSingle(ctx, f.bucket, pk, sk) + require.NoError(t, err) + + item, _, err := f.cli.PollItem(ctx, f.bucket, pk, sk, ct, 1*time.Second) + require.ErrorIs(t, err, k2v.NotModifiedTimeoutErr) + require.Empty(t, item) +} diff --git a/read_batch.go b/read_batch.go index 876c95c..e4382a5 100644 --- a/read_batch.go +++ b/read_batch.go @@ -10,7 +10,7 @@ import ( "net/url" ) -type ReadBatchSearch struct { +type BatchSearch struct { PartitionKey string `json:"partitionKey"` // Prefix restricts listing to partition keys that start with this value. @@ -54,12 +54,12 @@ type BatchSearchResult struct { } type SearchResultItem struct { - SortKey string `json:"sk"` - CausalityToken string `json:"ct"` - Values []Item `json:"v"` + SortKey string `json:"sk"` + CausalityToken CausalityToken `json:"ct"` + Values []Item `json:"v"` } -func (c *Client) ReadBatch(ctx context.Context, b Bucket, q []ReadBatchSearch) ([]BatchSearchResult, error) { +func (c *Client) ReadBatch(ctx context.Context, b Bucket, q []BatchSearch) ([]*BatchSearchResult, error) { u, err := url.Parse(c.endpoint) if err != nil { return nil, err @@ -91,7 +91,7 @@ func (c *Client) ReadBatch(ctx context.Context, b Bucket, q []ReadBatchSearch) ( return nil, fmt.Errorf("http status code %d: %s", resp.StatusCode, body) } - var items []BatchSearchResult + var items []*BatchSearchResult if err := json.Unmarshal(body, &items); err != nil { return nil, err } @@ -112,9 +112,9 @@ type BulkGetItem struct { } func BulkGet(ctx context.Context, cli *Client, b Bucket, keys []ItemKey) ([]BulkGetItem, error) { - q := make([]ReadBatchSearch, len(keys)) + q := make([]BatchSearch, len(keys)) for i := range keys { - q[i] = ReadBatchSearch{ + q[i] = BatchSearch{ PartitionKey: keys[i].PartitionKey, Start: keys[i].SortKey, SingleItem: true, diff --git a/read_batch_test.go b/read_batch_test.go index 47bd07b..524fbac 100644 --- a/read_batch_test.go +++ b/read_batch_test.go @@ -2,52 +2,79 @@ package k2v_test import ( k2v "code.notaphish.fyi/milas/garage-k2v-go" - "github.com/davecgh/go-spew/spew" + "fmt" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "math/rand/v2" "strconv" + "strings" "testing" ) func TestClient_ReadBatch(t *testing.T) { f, ctx := newFixture(t) - pk1 := randomKey() - sk1 := randomKey() + pk1 := randomKey("pk1") + sk1 := randomKey("sk1") + require.NoError(t, f.cli.InsertItem(ctx, f.bucket, pk1, sk1, "", []byte(strings.Join([]string{"hello", pk1, sk1}, "-")))) - require.NoError(t, f.cli.InsertItem(ctx, f.bucket, pk1, sk1, "", []byte("hello"))) - - pk2 := randomKey() + pk2 := randomKey("pk2") + sk2 := randomKey("sk2") for i := range 5 { - sk := randomKey() - require.NoError(t, f.cli.InsertItem(ctx, f.bucket, pk2, sk, "", []byte("hello-"+strconv.Itoa(i)))) + require.NoError(t, f.cli.InsertItem(ctx, f.bucket, pk2, sk2, "", []byte(strings.Join([]string{"hello", pk2, sk2, strconv.Itoa(i)}, "-")))) } - pk3 := randomKey() - sk3 := randomKey() + pk3 := randomKey("pk3") for i := range 5 { - require.NoError(t, f.cli.InsertItem(ctx, f.bucket, pk3, sk3, "", []byte("hello-"+strconv.Itoa(i)))) + skN := randomKey(fmt.Sprintf("sk%d", i+3)) + require.NoError(t, f.cli.InsertItem(ctx, f.bucket, pk3, skN, "", []byte(strings.Join([]string{"hello", pk3, skN, strconv.Itoa(i)}, "-")))) } - q := []k2v.ReadBatchSearch{ + q := []k2v.BatchSearch{ { PartitionKey: pk1, }, { PartitionKey: pk2, + SingleItem: true, + Start: sk2, }, { PartitionKey: pk3, - SingleItem: true, - Start: sk3, }, } - items, err := f.cli.ReadBatch(ctx, f.bucket, q) + results, err := f.cli.ReadBatch(ctx, f.bucket, q) require.NoError(t, err) - require.NotEmpty(t, items) + require.NotEmpty(t, results) + require.Len(t, results, 3) - spew.Dump(items) + assert.Equal(t, pk1, results[0].PartitionKey) + if assert.Len(t, results[0].Items, 1) && assert.Len(t, results[0].Items[0].Values, 1) { + assert.Equal(t, sk1, results[0].Items[0].SortKey) + assert.NotEmpty(t, results[0].Items[0].CausalityToken) + assert.Contains(t, results[0].Items[0].Values[0].GoString(), "hello") + } + + assert.Equal(t, pk2, results[1].PartitionKey) + if assert.Len(t, results[1].Items, 1) && assert.Len(t, results[1].Items[0].Values, 5) { + assert.Equal(t, sk2, results[1].Items[0].SortKey) + assert.NotEmpty(t, results[1].Items[0].CausalityToken) + for i := range results[1].Items[0].Values { + assert.Contains(t, results[1].Items[0].Values[i].GoString(), "hello") + } + } + + assert.Equal(t, pk3, results[2].PartitionKey) + if assert.Len(t, results[2].Items, 5) { + for _, item := range results[2].Items { + assert.NotEmpty(t, item.SortKey) + assert.NotEmpty(t, item.CausalityToken) + if assert.Len(t, item.Values, 1) { + assert.Contains(t, item.Values[0].GoString(), "hello") + } + } + } } func TestBulkGet(t *testing.T) { @@ -56,8 +83,8 @@ func TestBulkGet(t *testing.T) { keys := make([]k2v.ItemKey, 500) for i := range keys { keys[i] = k2v.ItemKey{ - PartitionKey: randomKey(), - SortKey: randomKey(), + PartitionKey: randomPk(), + SortKey: randomSk(), } require.NoError(t, f.cli.InsertItem(ctx, f.bucket, keys[i].PartitionKey, keys[i].SortKey, "", []byte("hello"+strconv.Itoa(i)))) }