mirror of
https://github.com/aljazceru/khatru.git
synced 2026-01-12 09:54:21 +01:00
Decouples PostgreSQL SQL generation from query execution. This allows the logic for building SQL to be unit tested without access to a database. This work was motivated when a client was not receiving events as expected. While debugging, I found that if a tag's value was an empty array, no query was executed — and, to my surprise, no error was raised either. I wanted to get a better sense of the current constraints on when queries are and are not executed, but I had a hard time keeping the code in my head. This led me to extract the SQL generation into its own function and to write unit tests that document its current behavior. This refactor makes no changes to the current logic. I have added some REVIEW comments in the test cases where I thought some error handling could be introduced, but I wanted to first see whether you were receptive to this refactor before proposing any functional changes.
176 lines
4.2 KiB
Go
176 lines
4.2 KiB
Go
package postgresql
|
|
|
|
import (
	"context"
	"database/sql"
	"encoding/hex"
	"errors"
	"fmt"
	"sort"
	"strconv"
	"strings"

	"github.com/jmoiron/sqlx"
	"github.com/nbd-wtf/go-nostr"
)
|
|
|
|
func (b PostgresBackend) QueryEvents(ctx context.Context, filter *nostr.Filter) (ch chan *nostr.Event, err error) {
|
|
ch = make(chan *nostr.Event)
|
|
|
|
query, params, err := queryEventsSql(filter)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
rows, err := b.DB.Query(query, params...)
|
|
if err != nil && err != sql.ErrNoRows {
|
|
return nil, fmt.Errorf("failed to fetch events using query %q: %w", query, err)
|
|
}
|
|
|
|
go func() {
|
|
defer rows.Close()
|
|
defer close(ch)
|
|
for rows.Next() {
|
|
var evt nostr.Event
|
|
var timestamp int64
|
|
err := rows.Scan(&evt.ID, &evt.PubKey, ×tamp,
|
|
&evt.Kind, &evt.Tags, &evt.Content, &evt.Sig)
|
|
if err != nil {
|
|
return
|
|
}
|
|
evt.CreatedAt = nostr.Timestamp(timestamp)
|
|
ch <- &evt
|
|
}
|
|
}()
|
|
|
|
return ch, nil
|
|
}
|
|
|
|
func queryEventsSql(filter *nostr.Filter) (string, []any, error) {
|
|
var conditions []string
|
|
var params []any
|
|
|
|
if filter == nil {
|
|
return "", nil, fmt.Errorf("filter cannot be null")
|
|
}
|
|
|
|
if filter.IDs != nil {
|
|
if len(filter.IDs) > 500 {
|
|
// too many ids, fail everything
|
|
return "", nil, nil
|
|
}
|
|
|
|
likeids := make([]string, 0, len(filter.IDs))
|
|
for _, id := range filter.IDs {
|
|
// to prevent sql attack here we will check if
|
|
// these ids are valid 32byte hex
|
|
parsed, err := hex.DecodeString(id)
|
|
if err != nil || len(parsed) != 32 {
|
|
continue
|
|
}
|
|
likeids = append(likeids, fmt.Sprintf("id LIKE '%x%%'", parsed))
|
|
}
|
|
if len(likeids) == 0 {
|
|
// ids being [] mean you won't get anything
|
|
return "", nil, nil
|
|
}
|
|
conditions = append(conditions, "("+strings.Join(likeids, " OR ")+")")
|
|
}
|
|
|
|
if filter.Authors != nil {
|
|
if len(filter.Authors) > 500 {
|
|
// too many authors, fail everything
|
|
return "", nil, nil
|
|
}
|
|
|
|
likekeys := make([]string, 0, len(filter.Authors))
|
|
for _, key := range filter.Authors {
|
|
// to prevent sql attack here we will check if
|
|
// these keys are valid 32byte hex
|
|
parsed, err := hex.DecodeString(key)
|
|
if err != nil || len(parsed) != 32 {
|
|
continue
|
|
}
|
|
likekeys = append(likekeys, fmt.Sprintf("pubkey LIKE '%x%%'", parsed))
|
|
}
|
|
if len(likekeys) == 0 {
|
|
// authors being [] mean you won't get anything
|
|
return "", nil, nil
|
|
}
|
|
conditions = append(conditions, "("+strings.Join(likekeys, " OR ")+")")
|
|
}
|
|
|
|
if filter.Kinds != nil {
|
|
if len(filter.Kinds) > 10 {
|
|
// too many kinds, fail everything
|
|
return "", nil, nil
|
|
}
|
|
|
|
if len(filter.Kinds) == 0 {
|
|
// kinds being [] mean you won't get anything
|
|
return "", nil, nil
|
|
}
|
|
// no sql injection issues since these are ints
|
|
inkinds := make([]string, len(filter.Kinds))
|
|
for i, kind := range filter.Kinds {
|
|
inkinds[i] = strconv.Itoa(kind)
|
|
}
|
|
conditions = append(conditions, `kind IN (`+strings.Join(inkinds, ",")+`)`)
|
|
}
|
|
|
|
tagQuery := make([]string, 0, 1)
|
|
for _, values := range filter.Tags {
|
|
if len(values) == 0 {
|
|
// any tag set to [] is wrong
|
|
return "", nil, nil
|
|
}
|
|
|
|
// add these tags to the query
|
|
tagQuery = append(tagQuery, values...)
|
|
|
|
if len(tagQuery) > 10 {
|
|
// too many tags, fail everything
|
|
return "", nil, nil
|
|
}
|
|
}
|
|
|
|
if len(tagQuery) > 0 {
|
|
arrayBuild := make([]string, len(tagQuery))
|
|
for i, tagValue := range tagQuery {
|
|
arrayBuild[i] = "?"
|
|
params = append(params, tagValue)
|
|
}
|
|
|
|
// we use a very bad implementation in which we only check the tag values and
|
|
// ignore the tag names
|
|
conditions = append(conditions,
|
|
"tagvalues && ARRAY["+strings.Join(arrayBuild, ",")+"]")
|
|
}
|
|
|
|
if filter.Since != nil {
|
|
conditions = append(conditions, "created_at > ?")
|
|
params = append(params, filter.Since)
|
|
}
|
|
if filter.Until != nil {
|
|
conditions = append(conditions, "created_at < ?")
|
|
params = append(params, filter.Until)
|
|
}
|
|
|
|
if len(conditions) == 0 {
|
|
// fallback
|
|
conditions = append(conditions, "true")
|
|
}
|
|
|
|
if filter.Limit < 1 || filter.Limit > 100 {
|
|
params = append(params, 100)
|
|
} else {
|
|
params = append(params, filter.Limit)
|
|
}
|
|
|
|
query := sqlx.Rebind(sqlx.BindType("postgres"), `SELECT
|
|
id, pubkey, created_at, kind, tags, content, sig
|
|
FROM event WHERE `+
|
|
strings.Join(conditions, " AND ")+
|
|
" ORDER BY created_at DESC LIMIT ?")
|
|
|
|
return query, params, nil
|
|
}
|