Compare commits

...

2 commits
v0.1.1 ... main

8 changed files with 100 additions and 16 deletions

View file

@ -19,9 +19,13 @@ import (
const CausalityTokenHeader = "X-Garage-Causality-Token" const CausalityTokenHeader = "X-Garage-Causality-Token"
const payloadHashEmpty = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
const payloadHashUnsigned = "UNSIGNED-PAYLOAD"
var TombstoneItemErr = errors.New("item is a tombstone") var TombstoneItemErr = errors.New("item is a tombstone")
var NoSuchItemErr = errors.New("item does not exist") var NoSuchItemErr = errors.New("item does not exist")
var ConcurrentItemsErr = errors.New("item has multiple concurrent values") var ConcurrentItemsErr = errors.New("item has multiple concurrent values")
var NotModifiedTimeoutErr = errors.New("not modified within timeout")
var awsSigner = v4.NewSigner() var awsSigner = v4.NewSigner()
@ -94,7 +98,7 @@ func (c *Client) executeRequest(req *http.Request) (*http.Response, error) {
return nil, err return nil, err
} }
resp, err := http.DefaultClient.Do(req) resp, err := c.httpClient.Do(req)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -103,14 +107,25 @@ func (c *Client) executeRequest(req *http.Request) (*http.Response, error) {
} }
func (c *Client) signRequest(req *http.Request) error { func (c *Client) signRequest(req *http.Request) error {
if c.key.ID == "" || c.key.Secret == "" {
return errors.New("no credentials provided")
}
creds := aws.Credentials{ creds := aws.Credentials{
AccessKeyID: c.key.ID, AccessKeyID: c.key.ID,
SecretAccessKey: c.key.Secret, SecretAccessKey: c.key.Secret,
} }
const noBody = "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
req.Header.Set("X-Amz-Content-Sha256", noBody)
err := awsSigner.SignHTTP(req.Context(), creds, req, noBody, "k2v", "garage", time.Now()) var payloadHash string
if req.Body == nil || req.Body == http.NoBody {
payloadHash = payloadHashEmpty
} else {
payloadHash = payloadHashUnsigned
}
req.Header.Set("X-Amz-Content-Sha256", payloadHash)
err := awsSigner.SignHTTP(req.Context(), creds, req, payloadHash, "k2v", "garage", time.Now())
if err != nil { if err != nil {
return err return err
} }
@ -247,9 +262,6 @@ func (c *Client) ReadItemMulti(ctx context.Context, b Bucket, pk string, sk stri
return []Item{body}, ct, nil return []Item{body}, ct, nil
case "application/json": case "application/json":
var items []Item var items []Item
if err != nil {
return nil, "", err
}
if err := json.Unmarshal(body, &items); err != nil { if err := json.Unmarshal(body, &items); err != nil {
return nil, "", err return nil, "", err
} }
@ -300,6 +312,8 @@ func (c *Client) readItemSingle(ctx context.Context, b Bucket, pk string, sk str
return nil, "", NoSuchItemErr return nil, "", NoSuchItemErr
case http.StatusConflict: case http.StatusConflict:
return nil, ct, ConcurrentItemsErr return nil, ct, ConcurrentItemsErr
case http.StatusNotModified:
return nil, "", NotModifiedTimeoutErr
default: default:
return nil, "", fmt.Errorf("http status code %d", resp.StatusCode) return nil, "", fmt.Errorf("http status code %d", resp.StatusCode)
} }

View file

@ -27,7 +27,7 @@ func newFixture(t testing.TB) (*fixture, context.Context) {
t: t, t: t,
ctx: ctx, ctx: ctx,
cli: cli, cli: cli,
bucket: k2v.Bucket("k2v-test"), bucket: TestBucket,
} }
return f, ctx return f, ctx

View file

@ -1,10 +1,16 @@
package k2v package k2v_test
import ( import (
k2v "code.notaphish.fyi/milas/garage-k2v-go"
"go.uber.org/goleak" "go.uber.org/goleak"
"os"
"testing" "testing"
) )
const BucketEnvVar = "K2V_TEST_BUCKET"
var TestBucket = k2v.Bucket(os.Getenv(BucketEnvVar))
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
goleak.VerifyTestMain(m) goleak.VerifyTestMain(m)
} }

View file

@ -32,15 +32,14 @@ func ExampleScrollIndex() {
ctx := context.Background() ctx := context.Background()
client := k2v.NewClient(k2v.EndpointFromEnv(), k2v.KeyFromEnv()) client := k2v.NewClient(k2v.EndpointFromEnv(), k2v.KeyFromEnv())
defer client.Close() defer client.Close()
const bucket = "k2v-test"
pkPrefix := randomPk() pkPrefix := randomPk()
for i := range 5 { for i := range 5 {
_ = client.InsertItem(ctx, bucket, pkPrefix+"-"+strconv.Itoa(i), randomSk(), "", []byte("hello")) _ = client.InsertItem(ctx, TestBucket, pkPrefix+"-"+strconv.Itoa(i), randomSk(), "", []byte("hello"))
} }
var responses []*k2v.ReadIndexResponse var responses []*k2v.ReadIndexResponse
_ = k2v.ScrollIndex(ctx, client, bucket, k2v.ReadIndexQuery{Prefix: pkPrefix, Limit: 25}, func(resp *k2v.ReadIndexResponse) error { _ = k2v.ScrollIndex(ctx, client, TestBucket, k2v.ReadIndexQuery{Prefix: pkPrefix, Limit: 25}, func(resp *k2v.ReadIndexResponse) error {
responses = append(responses, resp) responses = append(responses, resp)
return nil return nil
}) })

View file

@ -67,7 +67,12 @@ func (c *Client) PollRange(ctx context.Context, b Bucket, pk string, q PollRange
return nil, err return nil, err
} }
if resp.StatusCode != http.StatusOK { switch resp.StatusCode {
case http.StatusOK:
break
case http.StatusNotModified:
return nil, NotModifiedTimeoutErr
default:
return nil, fmt.Errorf("http status code %d: %s", resp.StatusCode, body) return nil, fmt.Errorf("http status code %d: %s", resp.StatusCode, body)
} }

View file

@ -80,3 +80,39 @@ func TestClient_PollRange(t *testing.T) {
require.Equal(t, "hello3", string(result.Items[0].Values[0])) require.Equal(t, "hello3", string(result.Items[0].Values[0]))
} }
} }
func TestClient_PollRange_Timeout(t *testing.T) {
if testing.Short() {
t.Skip("Skipping in short mode: 1 sec minimum to trigger timeout")
return
}
t.Parallel()
f, ctx := newFixture(t)
pk := randomPk()
sk := randomSk()
for i := range 5 {
err := f.cli.InsertItem(ctx, f.bucket, pk, sk+"-"+strconv.Itoa(i), "", []byte("hello1"))
require.NoError(t, err)
}
// first read should complete immediately
q := k2v.PollRangeQuery{
Start: sk,
}
result, err := f.cli.PollRange(ctx, f.bucket, pk, q, 5*time.Second)
require.NoError(t, err)
require.NotEmpty(t, result.SeenMarker)
require.Len(t, result.Items, 5)
for i := range result.Items {
require.Len(t, result.Items[i].Values, 1)
require.Equal(t, "hello1", string(result.Items[i].Values[0]))
}
q.SeenMarker = result.SeenMarker
result, err = f.cli.PollRange(ctx, f.bucket, pk, q, 1*time.Second)
require.ErrorIs(t, err, k2v.NotModifiedTimeoutErr)
require.Nil(t, result)
}

View file

@ -20,6 +20,7 @@ func TestClient_PollItem(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
_, ct, err := f.cli.ReadItemSingle(ctx, f.bucket, pk, sk) _, ct, err := f.cli.ReadItemSingle(ctx, f.bucket, pk, sk)
require.NoError(t, err)
updateErrCh := make(chan error, 1) updateErrCh := make(chan error, 1)
pollReadyCh := make(chan struct{}) pollReadyCh := make(chan struct{})
@ -46,3 +47,26 @@ func TestClient_PollItem(t *testing.T) {
require.NotEmpty(t, ct) require.NotEmpty(t, ct)
require.NoError(t, <-updateErrCh) require.NoError(t, <-updateErrCh)
} }
func TestClient_PollItem_Timeout(t *testing.T) {
if testing.Short() {
t.Skip("Skipping in short mode: 1 sec minimum to trigger timeout")
return
}
t.Parallel()
f, ctx := newFixture(t)
pk := randomPk()
sk := randomSk()
err := f.cli.InsertItem(ctx, f.bucket, pk, sk, "", []byte("hello1"))
require.NoError(t, err)
_, ct, err := f.cli.ReadItemSingle(ctx, f.bucket, pk, sk)
require.NoError(t, err)
item, _, err := f.cli.PollItem(ctx, f.bucket, pk, sk, ct, 1*time.Second)
require.ErrorIs(t, err, k2v.NotModifiedTimeoutErr)
require.Empty(t, item)
}

View file

@ -55,7 +55,7 @@ type BatchSearchResult struct {
type SearchResultItem struct { type SearchResultItem struct {
SortKey string `json:"sk"` SortKey string `json:"sk"`
CausalityToken string `json:"ct"` CausalityToken CausalityToken `json:"ct"`
Values []Item `json:"v"` Values []Item `json:"v"`
} }