Compare commits

...

7 commits
v0.1.0 ... main

13 changed files with 597 additions and 104 deletions

View file

@@ -33,12 +33,39 @@ type Client
func (c *Client) InsertBatch(ctx context.Context, b Bucket, items []BatchInsertItem) error
func (c *Client) InsertItem(ctx context.Context, b Bucket, pk string, sk string, ct CausalityToken, item []byte) error
func (c *Client) PollItem(ctx context.Context, b Bucket, pk string, sk string, ct CausalityToken, timeout time.Duration) (Item, CausalityToken, error)
func (c *Client) PollRange(ctx context.Context, b Bucket, pk string, q PollRangeQuery, timeout time.Duration) (*PollRangeResponse, error)
func (c *Client) ReadBatch(ctx context.Context, b Bucket, q []ReadBatchSearch) ([]BatchSearchResult, error)
func (c *Client) ReadIndex(ctx context.Context, b Bucket, q ReadIndexQuery) (*ReadIndexResponse, error)
func (c *Client) ReadItemMulti(ctx context.Context, b Bucket, pk string, sk string) ([]Item, CausalityToken, error)
func (c *Client) ReadItemSingle(ctx context.Context, b Bucket, pk string, sk string) (Item, CausalityToken, error)
```
## Scrolling (Client-side / Go API)
To handle iteration in the K2V API, helper functions for simple cases are provided.
For example, to perform a bulk search:
```go
handleBatch := func(result *k2v.BatchSearchResult) error {
log.Println(result.Items)
return nil
}
err := k2v.ScrollBatchSearch(ctx, f.cli, f.bucket, []k2v.BatchSearch{
{
PartitionKey: "pk1",
},
{
PartitionKey: "pk2",
Limit: 1,
},
}, handleBatch)
```
This will repeatedly make calls to **ReadBatch** (batch search), using `nextStart` from the responses to generate subsequent requests until all queries are exhausted.
See `ScrollIndex(ctx context.Context, client IndexScroller, b Bucket, query ReadIndexQuery, fn ReadIndexResponseHandler) error` for the equivalent for batch index reads.
No helper is available for `PollRange()` yet.
## Integration Tests
```shell
K2V_ENDPOINT="http://[::1]:3904" \
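
For comparison with the `ScrollBatchSearch` example in the Scrolling section above, here is a minimal `ScrollIndex` sketch built only from the signatures shown in this diff (the bucket name is illustrative):

```go
package main

import (
	"context"
	"log"

	k2v "code.notaphish.fyi/milas/garage-k2v-go"
)

func main() {
	ctx := context.Background()
	client := k2v.NewClient(k2v.EndpointFromEnv(), k2v.KeyFromEnv())
	defer client.Close()

	bucket := k2v.Bucket("example-bucket") // illustrative bucket name

	// Page through the partition key index; ScrollIndex follows nextStart
	// from each response until no more results remain.
	err := k2v.ScrollIndex(ctx, client, bucket, k2v.ReadIndexQuery{Limit: 100}, func(resp *k2v.ReadIndexResponse) error {
		log.Printf("got %d partition keys", len(resp.PartitionKeys))
		return nil // return k2v.StopScroll to stop early without an error
	})
	if err != nil {
		log.Fatal(err)
	}
}
```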

View file

@@ -19,9 +19,13 @@ import (
const CausalityTokenHeader = "X-Garage-Causality-Token"
const payloadHashEmpty = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
const payloadHashUnsigned = "UNSIGNED-PAYLOAD"
var TombstoneItemErr = errors.New("item is a tombstone")
var NoSuchItemErr = errors.New("item does not exist")
var ConcurrentItemsErr = errors.New("item has multiple concurrent values")
var NotModifiedTimeoutErr = errors.New("not modified within timeout")
var awsSigner = v4.NewSigner()
@@ -94,7 +98,7 @@ func (c *Client) executeRequest(req *http.Request) (*http.Response, error) {
return nil, err
}
-resp, err := http.DefaultClient.Do(req)
resp, err := c.httpClient.Do(req)
if err != nil {
return nil, err
}
@@ -103,14 +107,25 @@ func (c *Client) executeRequest(req *http.Request) (*http.Response, error) {
}
func (c *Client) signRequest(req *http.Request) error {
if c.key.ID == "" || c.key.Secret == "" {
return errors.New("no credentials provided")
}
creds := aws.Credentials{
AccessKeyID: c.key.ID,
SecretAccessKey: c.key.Secret,
}
-const noBody = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
-req.Header.Set("X-Amz-Content-Sha256", noBody)
-err := awsSigner.SignHTTP(req.Context(), creds, req, noBody, "k2v", "garage", time.Now())
var payloadHash string
if req.Body == nil || req.Body == http.NoBody {
payloadHash = payloadHashEmpty
} else {
payloadHash = payloadHashUnsigned
}
req.Header.Set("X-Amz-Content-Sha256", payloadHash)
err := awsSigner.SignHTTP(req.Context(), creds, req, payloadHash, "k2v", "garage", time.Now())
if err != nil {
return err
}
@@ -143,14 +158,14 @@ type ReadIndexResponsePartitionKey struct {
}
type ReadIndexResponse struct {
-Prefix any `json:"prefix"`
Prefix *string `json:"prefix"`
-Start any `json:"start"`
Start *string `json:"start"`
-End any `json:"end"`
End *string `json:"end"`
-Limit any `json:"limit"`
Limit *int `json:"limit"`
Reverse bool `json:"reverse"`
PartitionKeys []ReadIndexResponsePartitionKey `json:"partitionKeys"`
More bool `json:"more"`
-NextStart any `json:"nextStart"`
NextStart *string `json:"nextStart"`
}
func (c *Client) ReadIndex(ctx context.Context, b Bucket, q ReadIndexQuery) (*ReadIndexResponse, error) {
@@ -247,9 +262,6 @@ func (c *Client) ReadItemMulti(ctx context.Context, b Bucket, pk string, sk stri
return []Item{body}, ct, nil
case "application/json":
var items []Item
-if err != nil {
-return nil, "", err
-}
if err := json.Unmarshal(body, &items); err != nil {
return nil, "", err
}
@@ -300,6 +312,8 @@ func (c *Client) readItemSingle(ctx context.Context, b Bucket, pk string, sk str
return nil, "", NoSuchItemErr
case http.StatusConflict:
return nil, ct, ConcurrentItemsErr
case http.StatusNotModified:
return nil, "", NotModifiedTimeoutErr
default:
return nil, "", fmt.Errorf("http status code %d", resp.StatusCode)
}

View file

@@ -1,16 +1,12 @@
package k2v_test
import (
k2v "code.notaphish.fyi/milas/garage-k2v-go"
"context"
"github.com/stretchr/testify/require"
"math/rand/v2"
-"net/http/httptrace"
"strconv"
"testing"
-"time"
-k2v "code.notaphish.fyi/milas/garage-k2v-go"
-"github.com/stretchr/testify/require"
-"go.uber.org/goleak"
)
type fixture struct {
@@ -22,11 +18,6 @@ type fixture struct {
func newFixture(t testing.TB) (*fixture, context.Context) {
t.Helper()
-t.Cleanup(func() {
-goleak.VerifyNone(t)
-})
ctx := testContext(t)
cli := k2v.NewClient(k2v.EndpointFromEnv(), k2v.KeyFromEnv())
@@ -36,7 +27,7 @@ func newFixture(t testing.TB) (*fixture, context.Context) {
t: t,
ctx: ctx,
cli: cli,
-bucket: k2v.Bucket("k2v-test"),
bucket: TestBucket,
}
return f, ctx
@@ -48,23 +39,34 @@ func testContext(t testing.TB) context.Context {
return ctx
}
-func randomKey() string {
func randomKey(prefix string) string {
-return "key-" + strconv.Itoa(rand.IntN(1000000))
return prefix + "-" + strconv.Itoa(rand.IntN(1000000))
}
func randomPk() string {
return randomKey("pk")
}
func randomSk() string {
return randomKey("sk")
} }
func TestClient_InsertItem(t *testing.T) {
t.Parallel()
f, ctx := newFixture(t)
-err := f.cli.InsertItem(ctx, f.bucket, randomKey(), randomKey(), "", []byte("hello"))
err := f.cli.InsertItem(ctx, f.bucket, randomPk(), randomSk(), "", []byte("hello"))
require.NoError(t, err)
}
func TestClient_ReadItemNotExist(t *testing.T) {
t.Parallel()
f, ctx := newFixture(t)
-pk := randomKey()
pk := randomPk()
-sk := randomKey()
sk := randomSk()
t.Run("Single", func(t *testing.T) {
t.Parallel()
item, ct, err := f.cli.ReadItemSingle(ctx, f.bucket, pk, sk)
require.ErrorIs(t, err, k2v.NoSuchItemErr)
require.Nil(t, item)
@@ -72,6 +74,7 @@ func TestClient_ReadItemNotExist(t *testing.T) {
})
t.Run("Multi", func(t *testing.T) {
t.Parallel()
items, ct, err := f.cli.ReadItemMulti(ctx, f.bucket, pk, sk)
require.ErrorIs(t, err, k2v.NoSuchItemErr)
require.Empty(t, items)
@@ -80,10 +83,11 @@ }
}
func TestClient_ReadItemTombstone(t *testing.T) {
t.Parallel()
f, ctx := newFixture(t)
-pk := randomKey()
pk := randomPk()
-sk := randomKey()
sk := randomSk()
t.Logf("Creating item: PK=%s, SK=%s", pk, sk)
@@ -110,15 +114,17 @@ func TestClient_ReadItemTombstone(t *testing.T) {
}
func TestClient_ReadItemSingleRevision(t *testing.T) {
t.Parallel()
f, ctx := newFixture(t)
-pk := randomKey()
pk := randomPk()
-sk := randomKey()
sk := randomSk()
err := f.cli.InsertItem(ctx, f.bucket, pk, sk, "", []byte("hello"))
require.NoError(t, err)
t.Run("Single", func(t *testing.T) {
t.Parallel()
item, ct, err := f.cli.ReadItemSingle(ctx, f.bucket, pk, sk)
require.NoError(t, err)
require.Equal(t, "hello", string(item))
@@ -126,6 +132,7 @@ func TestClient_ReadItemSingleRevision(t *testing.T) {
})
t.Run("Multi", func(t *testing.T) {
t.Parallel()
items, ct, err := f.cli.ReadItemMulti(ctx, f.bucket, pk, sk)
require.NoError(t, err)
require.Len(t, items, 1)
@@ -135,10 +142,11 @@ func TestClient_ReadItemSingleRevision(t *testing.T) {
}
func TestClient_ReadItemMultipleRevisions(t *testing.T) {
t.Parallel()
f, ctx := newFixture(t)
-pk := randomKey()
pk := randomPk()
-sk := randomKey()
sk := randomSk()
err := f.cli.InsertItem(ctx, f.bucket, pk, sk, "", []byte("hello1"))
require.NoError(t, err)
@@ -148,6 +156,7 @@ func TestClient_ReadItemMultipleRevisions(t *testing.T) {
require.NoError(t, err)
t.Run("Single", func(t *testing.T) {
t.Parallel()
item, ct, err := f.cli.ReadItemSingle(ctx, f.bucket, pk, sk)
require.ErrorIs(t, err, k2v.ConcurrentItemsErr)
require.Nil(t, item)
@@ -155,6 +164,7 @@ func TestClient_ReadItemMultipleRevisions(t *testing.T) {
})
t.Run("Multi", func(t *testing.T) {
t.Parallel()
items, ct, err := f.cli.ReadItemMulti(ctx, f.bucket, pk, sk)
require.NoError(t, err)
require.Len(t, items, 2)
@@ -163,38 +173,3 @@ func TestClient_ReadItemMultipleRevisions(t *testing.T) {
require.NotEmpty(t, ct)
})
}
-func TestClient_PollItem(t *testing.T) {
-f, ctx := newFixture(t)
-pk := randomKey()
-sk := randomKey()
-err := f.cli.InsertItem(ctx, f.bucket, pk, sk, "", []byte("hello1"))
-require.NoError(t, err)
-_, ct, err := f.cli.ReadItemSingle(ctx, f.bucket, pk, sk)
-pollReadyCh := make(chan struct{})
-go func(ct k2v.CausalityToken) {
-select {
-case <-ctx.Done():
-t.Errorf("Context canceled: %v", ctx.Err())
-return
-case <-pollReadyCh:
-t.Logf("PollItem connected")
-}
-err = f.cli.InsertItem(ctx, f.bucket, pk, sk, ct, []byte("hello2"))
-require.NoError(t, err)
-}(ct)
-trace := &httptrace.ClientTrace{
-WroteRequest: func(_ httptrace.WroteRequestInfo) {
-pollReadyCh <- struct{}{}
-},
-}
-item, ct, err := f.cli.PollItem(httptrace.WithClientTrace(ctx, trace), f.bucket, pk, sk, ct, 5*time.Second)
-require.NoError(t, err)
-require.Equal(t, "hello2", string(item))
-require.NotEmpty(t, ct)
-}

go.mod

@@ -3,14 +3,14 @@ module code.notaphish.fyi/milas/garage-k2v-go
go 1.23.1
require (
-github.com/aws/aws-sdk-go-v2 v1.32.2
github.com/aws/aws-sdk-go-v2 v1.36.3
github.com/davecgh/go-spew v1.1.1
github.com/stretchr/testify v1.8.0
go.uber.org/goleak v1.3.0
)
require (
-github.com/aws/smithy-go v1.22.0 // indirect
github.com/aws/smithy-go v1.22.3 // indirect
-github.com/davecgh/go-spew v1.1.1 // indirect
github.com/kr/text v0.2.0 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect

go.sum

@@ -1,7 +1,7 @@
-github.com/aws/aws-sdk-go-v2 v1.32.2 h1:AkNLZEyYMLnx/Q/mSKkcMqwNFXMAvFto9bNsHqcTduI=
github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM=
-github.com/aws/aws-sdk-go-v2 v1.32.2/go.mod h1:2SK5n0a2karNTv5tbP1SjsX0uhttou00v/HpXKM1ZUo=
github.com/aws/aws-sdk-go-v2 v1.36.3/go.mod h1:LLXuLpgzEbD766Z5ECcRmi8AzSwfZItDtmABVkRLGzg=
-github.com/aws/smithy-go v1.22.0 h1:uunKnWlcoL3zO7q+gG2Pk53joueEOsnNB28QdMsmiMM=
github.com/aws/smithy-go v1.22.3 h1:Z//5NuZCSW6R4PhQ93hShNbyBbn8BWCmCVCt+Q8Io5k=
-github.com/aws/smithy-go v1.22.0/go.mod h1:irrKGvNn1InZwb2d7fkIRNucdfwR8R+Ts3wxYa/cJHg=
github.com/aws/smithy-go v1.22.3/go.mod h1:t1ufH5HMublsJYulve2RKmHDC15xu1f26kHCp/HgceI=
github.com/creack/pty v1.1.9/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=

main_test.go (new file)

@@ -0,0 +1,16 @@
package k2v_test
import (
k2v "code.notaphish.fyi/milas/garage-k2v-go"
"go.uber.org/goleak"
"os"
"testing"
)
const BucketEnvVar = "K2V_TEST_BUCKET"
var TestBucket = k2v.Bucket(os.Getenv(BucketEnvVar))
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}

pager.go (new file)

@@ -0,0 +1,77 @@
package k2v
import (
"context"
"errors"
)
var StopScroll = errors.New("scroll canceled")
// ReadIndexResponseHandler is invoked for each batch of index read results.
//
// If an error is returned, scrolling is halted and the error is propagated.
// The sentinel value StopScroll can be returned to end iteration early without propagating an error.
type ReadIndexResponseHandler func(resp *ReadIndexResponse) error
type BatchSearchResultHandler func(result *BatchSearchResult) error
type IndexScroller interface {
ReadIndex(ctx context.Context, b Bucket, q ReadIndexQuery) (*ReadIndexResponse, error)
}
var _ IndexScroller = &Client{}
type BatchSearchScroller interface {
ReadBatch(ctx context.Context, b Bucket, q []BatchSearch) ([]*BatchSearchResult, error)
}
var _ BatchSearchScroller = &Client{}
// ScrollIndex calls the ReadIndex API serially, invoking the provided function for each response (batch) until there are no more results.
func ScrollIndex(ctx context.Context, client IndexScroller, b Bucket, q ReadIndexQuery, fn ReadIndexResponseHandler) error {
for {
resp, err := client.ReadIndex(ctx, b, q)
if err != nil {
return err
}
if err := fn(resp); err != nil {
if errors.Is(err, StopScroll) {
return nil
}
return err
}
if !resp.More || resp.NextStart == nil {
break
}
q.Start = *resp.NextStart
}
return nil
}
func ScrollBatchSearch(ctx context.Context, client BatchSearchScroller, b Bucket, q []BatchSearch, fn BatchSearchResultHandler) error {
for {
results, err := client.ReadBatch(ctx, b, q)
if err != nil {
return err
}
var nextQ []BatchSearch
for i := range results {
if results[i].More && results[i].NextStart != nil {
batch := q[i]
batch.Start = *results[i].NextStart
nextQ = append(nextQ, batch)
}
if err := fn(results[i]); err != nil {
if errors.Is(err, StopScroll) {
return nil
}
return err
}
}
if len(nextQ) == 0 {
break
}
q = nextQ
}
return nil
}
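
A usage note for the handler types in pager.go above: returning the `StopScroll` sentinel from a handler ends scrolling early and the scroll function returns nil. A rough sketch, assuming only the API shown in this diff (bucket and partition key are made up):

```go
package main

import (
	"context"
	"log"

	k2v "code.notaphish.fyi/milas/garage-k2v-go"
)

func main() {
	ctx := context.Background()
	client := k2v.NewClient(k2v.EndpointFromEnv(), k2v.KeyFromEnv())
	defer client.Close()

	bucket := k2v.Bucket("example-bucket") // illustrative

	// Stop scrolling once 10 items have been collected; StopScroll is
	// swallowed by ScrollBatchSearch, so err stays nil in that case.
	seen := 0
	err := k2v.ScrollBatchSearch(ctx, client, bucket, []k2v.BatchSearch{{PartitionKey: "example-pk", Limit: 5}}, func(result *k2v.BatchSearchResult) error {
		seen += len(result.Items)
		if seen >= 10 {
			return k2v.StopScroll
		}
		return nil
	})
	if err != nil {
		log.Fatal(err)
	}
	log.Printf("collected %d items", seen)
}
```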

pager_test.go (new file)

@@ -0,0 +1,83 @@
package k2v_test
import (
k2v "code.notaphish.fyi/milas/garage-k2v-go"
"context"
"fmt"
"github.com/stretchr/testify/require"
"strconv"
"strings"
"testing"
)
func TestScrollIndex(t *testing.T) {
t.Parallel()
f, ctx := newFixture(t)
pkPrefix := randomPk()
for i := range 5 {
require.NoError(t, f.cli.InsertItem(ctx, f.bucket, pkPrefix+"-"+strconv.Itoa(i), randomSk(), "", []byte("hello"+strconv.Itoa(i))))
}
var responses []*k2v.ReadIndexResponse
err := k2v.ScrollIndex(ctx, f.cli, f.bucket, k2v.ReadIndexQuery{Prefix: pkPrefix, Limit: 1}, func(resp *k2v.ReadIndexResponse) error {
responses = append(responses, resp)
return nil
})
require.NoError(t, err)
require.Len(t, responses, 5)
}
func ExampleScrollIndex() {
ctx := context.Background()
client := k2v.NewClient(k2v.EndpointFromEnv(), k2v.KeyFromEnv())
defer client.Close()
pkPrefix := randomPk()
for i := range 5 {
_ = client.InsertItem(ctx, TestBucket, pkPrefix+"-"+strconv.Itoa(i), randomSk(), "", []byte("hello"))
}
var responses []*k2v.ReadIndexResponse
_ = k2v.ScrollIndex(ctx, client, TestBucket, k2v.ReadIndexQuery{Prefix: pkPrefix, Limit: 25}, func(resp *k2v.ReadIndexResponse) error {
responses = append(responses, resp)
return nil
})
fmt.Println(len(responses[0].PartitionKeys))
// Output:
// 5
}
func TestScrollItems(t *testing.T) {
t.Parallel()
f, ctx := newFixture(t)
pk1 := randomKey("pk1")
sk1 := randomKey("sk1")
require.NoError(t, f.cli.InsertItem(ctx, f.bucket, pk1, sk1, "", []byte(strings.Join([]string{"hello", pk1, sk1}, "-"))))
pk2 := randomKey("pk2")
for i := range 5 {
skN := randomKey(fmt.Sprintf("sk%d", i+2))
require.NoError(t, f.cli.InsertItem(ctx, f.bucket, pk2, skN, "", []byte(strings.Join([]string{"hello", pk2, skN, strconv.Itoa(i)}, "-"))))
}
q := []k2v.BatchSearch{
{
PartitionKey: pk1,
},
{
PartitionKey: pk2,
Limit: 1,
},
}
var results []*k2v.BatchSearchResult
err := k2v.ScrollBatchSearch(ctx, f.cli, f.bucket, q, func(result *k2v.BatchSearchResult) error {
results = append(results, result)
return nil
})
require.NoError(t, err)
require.NotEmpty(t, results)
require.Len(t, results, 6)
}

poll_batch.go (new file)

@@ -0,0 +1,84 @@
package k2v
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"time"
)
type PollRangeQuery struct {
// Prefix restricts items to poll to those whose sort keys start with this prefix.
Prefix string `json:"prefix,omitempty"`
// Start is the sort key of the first item to poll.
Start string `json:"start,omitempty"`
// End is the sort key of the last item to poll (excluded).
End string `json:"end,omitempty"`
// SeenMarker is an opaque string returned by a previous PollRange call, that represents items already seen.
SeenMarker string `json:"seenMarker,omitempty"`
}
type PollRangeResponse struct {
SeenMarker string `json:"seenMarker"`
Items []SearchResultItem `json:"items"`
}
func (c *Client) PollRange(ctx context.Context, b Bucket, pk string, q PollRangeQuery, timeout time.Duration) (*PollRangeResponse, error) {
u, err := url.Parse(c.endpoint)
if err != nil {
return nil, err
}
u.Path = string(b) + "/" + pk
query := make(url.Values)
query.Set("poll_range", "")
u.RawQuery = query.Encode()
reqBody, err := json.Marshal(struct {
PollRangeQuery
Timeout int `json:"timeout,omitempty"`
}{
PollRangeQuery: q,
Timeout: int(timeout.Seconds()),
})
if err != nil {
return nil, err
}
req, err := http.NewRequestWithContext(ctx, "SEARCH", u.String(), bytes.NewReader(reqBody))
if err != nil {
return nil, err
}
resp, err := c.executeRequest(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, err
}
switch resp.StatusCode {
case http.StatusOK:
break
case http.StatusNotModified:
return nil, NotModifiedTimeoutErr
default:
return nil, fmt.Errorf("http status code %d: %s", resp.StatusCode, body)
}
var result PollRangeResponse
if err := json.Unmarshal(body, &result); err != nil {
return nil, err
}
return &result, nil
}
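
PollRange (above) is designed to be called in a loop: feed each response's `SeenMarker` back into the next query so that only items changed since that marker are returned, and treat `NotModifiedTimeoutErr` as "nothing changed within the timeout". A rough sketch of that loop, assuming only what this diff shows (bucket, partition key, and timeout are illustrative):

```go
package main

import (
	"context"
	"errors"
	"log"
	"time"

	k2v "code.notaphish.fyi/milas/garage-k2v-go"
)

func main() {
	ctx := context.Background()
	client := k2v.NewClient(k2v.EndpointFromEnv(), k2v.KeyFromEnv())
	defer client.Close()

	bucket := k2v.Bucket("example-bucket") // illustrative
	pk := "example-pk"                     // illustrative

	q := k2v.PollRangeQuery{} // no SeenMarker yet: the first call returns the current items immediately
	for {
		resp, err := client.PollRange(ctx, bucket, pk, q, 60*time.Second)
		if errors.Is(err, k2v.NotModifiedTimeoutErr) {
			continue // nothing changed within the timeout; poll again
		}
		if err != nil {
			log.Fatal(err)
		}
		for _, item := range resp.Items {
			log.Printf("sk=%s values=%d", item.SortKey, len(item.Values))
		}
		q.SeenMarker = resp.SeenMarker // only items changed after this marker are returned next time
	}
}
```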

poll_batch_test.go (new file)

@@ -0,0 +1,118 @@
package k2v_test
import (
k2v "code.notaphish.fyi/milas/garage-k2v-go"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"net/http/httptrace"
"strconv"
"testing"
"time"
)
func TestClient_PollRange(t *testing.T) {
t.Parallel()
f, ctx := newFixture(t)
pk := randomPk()
sk := randomSk()
for i := range 5 {
err := f.cli.InsertItem(ctx, f.bucket, pk, sk+"-"+strconv.Itoa(i), "", []byte("hello1"))
require.NoError(t, err)
}
// first read should complete immediately
q := k2v.PollRangeQuery{
Start: sk,
}
result, err := f.cli.PollRange(ctx, f.bucket, pk, q, 5*time.Second)
require.NoError(t, err)
require.NotEmpty(t, result.SeenMarker)
require.Len(t, result.Items, 5)
for i := range result.Items {
require.Len(t, result.Items[i].Values, 1)
require.Equal(t, "hello1", string(result.Items[i].Values[0]))
}
updateErrCh := make(chan error, 1)
pollReadyCh := make(chan struct{})
go func(sk string, ct k2v.CausalityToken) {
defer close(updateErrCh)
select {
case <-ctx.Done():
t.Errorf("Context canceled: %v", ctx.Err())
return
case <-pollReadyCh:
t.Logf("PollRange connected")
}
updateErrCh <- f.cli.InsertItem(ctx, f.bucket, pk, sk, ct, []byte("hello2"))
}(result.Items[3].SortKey, k2v.CausalityToken(result.Items[3].CausalityToken))
trace := &httptrace.ClientTrace{
WroteRequest: func(_ httptrace.WroteRequestInfo) {
pollReadyCh <- struct{}{}
},
}
q.SeenMarker = result.SeenMarker
result, err = f.cli.PollRange(httptrace.WithClientTrace(ctx, trace), f.bucket, pk, q, 5*time.Second)
if assert.NoError(t, err) {
require.NotEmpty(t, result.SeenMarker)
require.Len(t, result.Items, 1)
require.Len(t, result.Items[0].Values, 1)
require.Equal(t, sk+"-3", result.Items[0].SortKey)
require.Equal(t, "hello2", string(result.Items[0].Values[0]))
}
require.NoError(t, <-updateErrCh)
require.NoError(t, err)
require.NoError(t, f.cli.InsertItem(ctx, f.bucket, pk, result.Items[0].SortKey, k2v.CausalityToken(result.Items[0].CausalityToken), []byte("hello3")))
q.SeenMarker = result.SeenMarker
result, err = f.cli.PollRange(ctx, f.bucket, pk, q, 5*time.Second)
if assert.NoError(t, err) {
require.NotEmpty(t, result.SeenMarker)
require.Len(t, result.Items, 1)
require.Len(t, result.Items[0].Values, 1)
require.Equal(t, sk+"-3", result.Items[0].SortKey)
require.Equal(t, "hello3", string(result.Items[0].Values[0]))
}
}
func TestClient_PollRange_Timeout(t *testing.T) {
if testing.Short() {
t.Skip("Skipping in short mode: 1 sec minimum to trigger timeout")
return
}
t.Parallel()
f, ctx := newFixture(t)
pk := randomPk()
sk := randomSk()
for i := range 5 {
err := f.cli.InsertItem(ctx, f.bucket, pk, sk+"-"+strconv.Itoa(i), "", []byte("hello1"))
require.NoError(t, err)
}
// first read should complete immediately
q := k2v.PollRangeQuery{
Start: sk,
}
result, err := f.cli.PollRange(ctx, f.bucket, pk, q, 5*time.Second)
require.NoError(t, err)
require.NotEmpty(t, result.SeenMarker)
require.Len(t, result.Items, 5)
for i := range result.Items {
require.Len(t, result.Items[i].Values, 1)
require.Equal(t, "hello1", string(result.Items[i].Values[0]))
}
q.SeenMarker = result.SeenMarker
result, err = f.cli.PollRange(ctx, f.bucket, pk, q, 1*time.Second)
require.ErrorIs(t, err, k2v.NotModifiedTimeoutErr)
require.Nil(t, result)
}

poll_single_test.go (new file)

@@ -0,0 +1,72 @@
package k2v_test
import (
k2v "code.notaphish.fyi/milas/garage-k2v-go"
"github.com/stretchr/testify/require"
"net/http/httptrace"
"testing"
"time"
)
func TestClient_PollItem(t *testing.T) {
t.Parallel()
f, ctx := newFixture(t)
pk := randomPk()
sk := randomSk()
err := f.cli.InsertItem(ctx, f.bucket, pk, sk, "", []byte("hello1"))
require.NoError(t, err)
_, ct, err := f.cli.ReadItemSingle(ctx, f.bucket, pk, sk)
require.NoError(t, err)
updateErrCh := make(chan error, 1)
pollReadyCh := make(chan struct{})
go func(ct k2v.CausalityToken) {
defer close(updateErrCh)
select {
case <-ctx.Done():
t.Errorf("Context canceled: %v", ctx.Err())
return
case <-pollReadyCh:
t.Logf("PollItem connected")
}
updateErrCh <- f.cli.InsertItem(ctx, f.bucket, pk, sk, ct, []byte("hello2"))
}(ct)
trace := &httptrace.ClientTrace{
WroteRequest: func(_ httptrace.WroteRequestInfo) {
pollReadyCh <- struct{}{}
},
}
item, ct, err := f.cli.PollItem(httptrace.WithClientTrace(ctx, trace), f.bucket, pk, sk, ct, 5*time.Second)
require.NoError(t, err)
require.Equal(t, "hello2", string(item))
require.NotEmpty(t, ct)
require.NoError(t, <-updateErrCh)
}
func TestClient_PollItem_Timeout(t *testing.T) {
if testing.Short() {
t.Skip("Skipping in short mode: 1 sec minimum to trigger timeout")
return
}
t.Parallel()
f, ctx := newFixture(t)
pk := randomPk()
sk := randomSk()
err := f.cli.InsertItem(ctx, f.bucket, pk, sk, "", []byte("hello1"))
require.NoError(t, err)
_, ct, err := f.cli.ReadItemSingle(ctx, f.bucket, pk, sk)
require.NoError(t, err)
item, _, err := f.cli.PollItem(ctx, f.bucket, pk, sk, ct, 1*time.Second)
require.ErrorIs(t, err, k2v.NotModifiedTimeoutErr)
require.Empty(t, item)
}

View file

@@ -10,7 +10,7 @@ import (
"net/url"
)
-type ReadBatchSearch struct {
type BatchSearch struct {
PartitionKey string `json:"partitionKey"`
// Prefix restricts listing to partition keys that start with this value.
@@ -54,12 +54,12 @@ type BatchSearchResult struct {
}
type SearchResultItem struct {
SortKey string `json:"sk"`
-CausalityToken string `json:"ct"`
CausalityToken CausalityToken `json:"ct"`
Values []Item `json:"v"`
}
-func (c *Client) ReadBatch(ctx context.Context, b Bucket, q []ReadBatchSearch) ([]BatchSearchResult, error) {
func (c *Client) ReadBatch(ctx context.Context, b Bucket, q []BatchSearch) ([]*BatchSearchResult, error) {
u, err := url.Parse(c.endpoint)
if err != nil {
return nil, err
@@ -91,7 +91,7 @@ func (c *Client) ReadBatch(ctx context.Context, b Bucket, q []ReadBatchSearch) (
return nil, fmt.Errorf("http status code %d: %s", resp.StatusCode, body)
}
-var items []BatchSearchResult
var items []*BatchSearchResult
if err := json.Unmarshal(body, &items); err != nil {
return nil, err
}
@@ -112,9 +112,9 @@ type BulkGetItem struct {
}
func BulkGet(ctx context.Context, cli *Client, b Bucket, keys []ItemKey) ([]BulkGetItem, error) {
-q := make([]ReadBatchSearch, len(keys))
q := make([]BatchSearch, len(keys))
for i := range keys {
-q[i] = ReadBatchSearch{
q[i] = BatchSearch{
PartitionKey: keys[i].PartitionKey,
Start: keys[i].SortKey,
SingleItem: true,

View file

@@ -2,52 +2,79 @@ package k2v_test
import (
k2v "code.notaphish.fyi/milas/garage-k2v-go"
-"github.com/davecgh/go-spew/spew"
"fmt"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"math/rand/v2" "math/rand/v2"
"strconv" "strconv"
"strings"
"testing" "testing"
) )
func TestClient_ReadBatch(t *testing.T) { func TestClient_ReadBatch(t *testing.T) {
f, ctx := newFixture(t) f, ctx := newFixture(t)
pk1 := randomKey() pk1 := randomKey("pk1")
sk1 := randomKey() sk1 := randomKey("sk1")
require.NoError(t, f.cli.InsertItem(ctx, f.bucket, pk1, sk1, "", []byte(strings.Join([]string{"hello", pk1, sk1}, "-"))))
-require.NoError(t, f.cli.InsertItem(ctx, f.bucket, pk1, sk1, "", []byte("hello")))
pk2 := randomKey("pk2")
sk2 := randomKey("sk2")
-pk2 := randomKey()
for i := range 5 {
-sk := randomKey()
require.NoError(t, f.cli.InsertItem(ctx, f.bucket, pk2, sk2, "", []byte(strings.Join([]string{"hello", pk2, sk2, strconv.Itoa(i)}, "-"))))
-require.NoError(t, f.cli.InsertItem(ctx, f.bucket, pk2, sk, "", []byte("hello-"+strconv.Itoa(i))))
}
-pk3 := randomKey()
pk3 := randomKey("pk3")
-sk3 := randomKey()
for i := range 5 {
-require.NoError(t, f.cli.InsertItem(ctx, f.bucket, pk3, sk3, "", []byte("hello-"+strconv.Itoa(i))))
skN := randomKey(fmt.Sprintf("sk%d", i+3))
require.NoError(t, f.cli.InsertItem(ctx, f.bucket, pk3, skN, "", []byte(strings.Join([]string{"hello", pk3, skN, strconv.Itoa(i)}, "-"))))
}
-q := []k2v.ReadBatchSearch{
q := []k2v.BatchSearch{
{
PartitionKey: pk1,
},
{
PartitionKey: pk2,
SingleItem: true,
Start: sk2,
},
{
PartitionKey: pk3,
-SingleItem: true,
-Start: sk3,
},
}
-items, err := f.cli.ReadBatch(ctx, f.bucket, q)
results, err := f.cli.ReadBatch(ctx, f.bucket, q)
require.NoError(t, err)
-require.NotEmpty(t, items)
require.NotEmpty(t, results)
require.Len(t, results, 3)
-spew.Dump(items)
assert.Equal(t, pk1, results[0].PartitionKey)
if assert.Len(t, results[0].Items, 1) && assert.Len(t, results[0].Items[0].Values, 1) {
assert.Equal(t, sk1, results[0].Items[0].SortKey)
assert.NotEmpty(t, results[0].Items[0].CausalityToken)
assert.Contains(t, results[0].Items[0].Values[0].GoString(), "hello")
}
assert.Equal(t, pk2, results[1].PartitionKey)
if assert.Len(t, results[1].Items, 1) && assert.Len(t, results[1].Items[0].Values, 5) {
assert.Equal(t, sk2, results[1].Items[0].SortKey)
assert.NotEmpty(t, results[1].Items[0].CausalityToken)
for i := range results[1].Items[0].Values {
assert.Contains(t, results[1].Items[0].Values[i].GoString(), "hello")
}
}
assert.Equal(t, pk3, results[2].PartitionKey)
if assert.Len(t, results[2].Items, 5) {
for _, item := range results[2].Items {
assert.NotEmpty(t, item.SortKey)
assert.NotEmpty(t, item.CausalityToken)
if assert.Len(t, item.Values, 1) {
assert.Contains(t, item.Values[0].GoString(), "hello")
}
}
}
}
func TestBulkGet(t *testing.T) {
@@ -56,8 +83,8 @@ func TestBulkGet(t *testing.T) {
keys := make([]k2v.ItemKey, 500)
for i := range keys {
keys[i] = k2v.ItemKey{
-PartitionKey: randomKey(),
PartitionKey: randomPk(),
-SortKey: randomKey(),
SortKey: randomSk(),
}
require.NoError(t, f.cli.InsertItem(ctx, f.bucket, keys[i].PartitionKey, keys[i].SortKey, "", []byte("hello"+strconv.Itoa(i))))
}