Files
khatru/storage/postgresql/query.go
bndw ae641fd24d refactor(postgres): make SQL generation testable
Decouples the postgresql sql generation from the query execution.
This allows the logic for building sql to be unit tested without
access to a database.

This work was motivated when a client was not receiving events as
expected. In debugging I found that if a tag's value was an empty array,
then no query would be executed - and to my surprise no error is
raised either. I wanted to get a better sense of the current constraints on
when queries are and are not executed, but I had a hard time keeping the
code in my head. This led me to extracting the sql generation into its
own function and writing the unit tests that document its current
behavior. This refactor makes no changes to the current logic. I have added
some REVIEW comments in the test cases where I thought some error handling
could be introduced but I wanted to first see if you were receptive to this
refactor before proposing any functional changes.
2023-05-02 12:41:51 -03:00

176 lines
4.2 KiB
Go

package postgresql
import (
	"context"
	"database/sql"
	"encoding/hex"
	"fmt"
	"sort"
	"strconv"
	"strings"

	"github.com/jmoiron/sqlx"
	"github.com/nbd-wtf/go-nostr"
)
// QueryEvents runs the SQL produced by queryEventsSql for filter and streams
// every resulting row as an event on the returned channel.
//
// The channel is unbuffered. It is closed once all rows have been delivered,
// a row fails to scan, or ctx is cancelled. When queryEventsSql returns an
// empty query (its signal that the filter cannot match anything) no statement
// is sent to the database and an already-closed channel is returned.
//
// A non-nil error is returned when queryEventsSql rejects the filter or when
// the database fails to run the query; in that case the channel is nil.
func (b PostgresBackend) QueryEvents(ctx context.Context, filter *nostr.Filter) (ch chan *nostr.Event, err error) {
	query, params, err := queryEventsSql(filter)
	if err != nil {
		return nil, err
	}

	ch = make(chan *nostr.Event)
	if query == "" {
		// Nothing can match: hand back an empty stream instead of executing
		// an empty statement.
		close(ch)
		return ch, nil
	}

	rows, err := b.DB.QueryContext(ctx, query, params...)
	if err != nil {
		if err == sql.ErrNoRows {
			// Treated as "no results", as before, but without touching a
			// rows value that may be nil.
			close(ch)
			return ch, nil
		}
		return nil, fmt.Errorf("failed to fetch events using query %q: %w", query, err)
	}

	go func() {
		defer rows.Close()
		defer close(ch)

		for rows.Next() {
			var evt nostr.Event
			var timestamp int64
			if err := rows.Scan(&evt.ID, &evt.PubKey, &timestamp,
				&evt.Kind, &evt.Tags, &evt.Content, &evt.Sig); err != nil {
				// There is no channel for reporting per-row failures, so the
				// stream simply ends here.
				return
			}
			evt.CreatedAt = nostr.Timestamp(timestamp)

			// Stop instead of blocking forever if the consumer went away.
			select {
			case ch <- &evt:
			case <-ctx.Done():
				return
			}
		}
	}()

	return ch, nil
}
// queryEventsSql translates filter into a SELECT statement over the event
// table, without touching the database.
//
// It returns the query text (rebound with sqlx for the "postgres" bind type)
// and the parameters to bind, in placeholder order. The last parameter is
// always the LIMIT: filter.Limit when it lies in [1, 100], otherwise 100.
//
// A non-nil error is returned only when filter is nil. An empty query with
// nil parameters and a nil error is returned when the filter is rejected or
// cannot match anything:
//   - more than 500 ids or authors, more than 10 kinds, or more than 10 tag
//     values in total;
//   - a non-nil ids or authors list that contains no valid 32-byte hex value;
//   - a non-nil but empty kinds list;
//   - any tag whose value list is empty.
//
// Tag names are visited in sorted order so that the generated parameters are
// deterministic; ranging over the map directly would make their order vary
// between calls.
func queryEventsSql(filter *nostr.Filter) (string, []any, error) {
	if filter == nil {
		return "", nil, fmt.Errorf("filter cannot be null")
	}

	var conditions []string
	var params []any

	if filter.IDs != nil {
		cond, ok := hexPrefixCondition("id", filter.IDs)
		if !ok {
			// too many ids, or none of them usable: fail everything
			return "", nil, nil
		}
		conditions = append(conditions, cond)
	}

	if filter.Authors != nil {
		cond, ok := hexPrefixCondition("pubkey", filter.Authors)
		if !ok {
			// too many authors, or none of them usable: fail everything
			return "", nil, nil
		}
		conditions = append(conditions, cond)
	}

	if filter.Kinds != nil {
		if len(filter.Kinds) > 10 {
			// too many kinds, fail everything
			return "", nil, nil
		}
		if len(filter.Kinds) == 0 {
			// kinds being [] mean you won't get anything
			return "", nil, nil
		}

		// no sql injection issues since these are ints
		inkinds := make([]string, len(filter.Kinds))
		for i, kind := range filter.Kinds {
			inkinds[i] = strconv.Itoa(kind)
		}
		conditions = append(conditions, `kind IN (`+strings.Join(inkinds, ",")+`)`)
	}

	// Collect the tag names first and sort them: map iteration order is
	// unspecified, and the parameter order should not change between calls.
	tagNames := make([]string, 0, len(filter.Tags))
	for name := range filter.Tags {
		tagNames = append(tagNames, name)
	}
	sort.Strings(tagNames)

	tagValues := make([]string, 0, 1)
	for _, name := range tagNames {
		values := filter.Tags[name]
		if len(values) == 0 {
			// any tag set to [] is wrong
			return "", nil, nil
		}

		// add these tags to the query
		tagValues = append(tagValues, values...)
		if len(tagValues) > 10 {
			// too many tags, fail everything
			return "", nil, nil
		}
	}

	if len(tagValues) > 0 {
		placeholders := make([]string, len(tagValues))
		for i, value := range tagValues {
			placeholders[i] = "?"
			params = append(params, value)
		}

		// we use a very bad implementation in which we only check the tag values and
		// ignore the tag names
		conditions = append(conditions,
			"tagvalues && ARRAY["+strings.Join(placeholders, ",")+"]")
	}

	if filter.Since != nil {
		conditions = append(conditions, "created_at > ?")
		params = append(params, filter.Since)
	}
	if filter.Until != nil {
		conditions = append(conditions, "created_at < ?")
		params = append(params, filter.Until)
	}

	if len(conditions) == 0 {
		// fallback: no restriction at all
		conditions = append(conditions, "true")
	}

	// Out-of-range limits fall back to the maximum of 100.
	if filter.Limit < 1 || filter.Limit > 100 {
		params = append(params, 100)
	} else {
		params = append(params, filter.Limit)
	}

	query := sqlx.Rebind(sqlx.BindType("postgres"), `SELECT
id, pubkey, created_at, kind, tags, content, sig
FROM event WHERE `+
		strings.Join(conditions, " AND ")+
		" ORDER BY created_at DESC LIMIT ?")

	return query, params, nil
}

// hexPrefixCondition builds the parenthesised, OR-joined
// "<column> LIKE '<hex>%'" condition shared by the ids and authors filters.
//
// Values that are not valid hex encodings of exactly 32 bytes are skipped;
// because only re-encoded hex is interpolated, the resulting text is safe to
// embed in SQL. ok is false when values has more than 500 entries or when no
// entry survives validation, which callers treat as "return nothing".
func hexPrefixCondition(column string, values []string) (cond string, ok bool) {
	if len(values) > 500 {
		return "", false
	}

	clauses := make([]string, 0, len(values))
	for _, value := range values {
		// to prevent sql attack here we will check if
		// these values are valid 32byte hex
		parsed, err := hex.DecodeString(value)
		if err != nil || len(parsed) != 32 {
			continue
		}
		clauses = append(clauses, fmt.Sprintf("%s LIKE '%x%%'", column, parsed))
	}
	if len(clauses) == 0 {
		return "", false
	}

	return "(" + strings.Join(clauses, " OR ") + ")", true
}