summaryrefslogtreecommitdiffhomepage
path: root/src/dbx/db.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/dbx/db.go')
-rw-r--r--src/dbx/db.go232
1 files changed, 113 insertions, 119 deletions
diff --git a/src/dbx/db.go b/src/dbx/db.go
index fcb345e..b8112b9 100644
--- a/src/dbx/db.go
+++ b/src/dbx/db.go
@@ -1,27 +1,24 @@
package dbx
import (
- "database/sql"
"fmt"
"io/fs"
"log"
- "reflect"
"sort"
- "strings"
"git.thomasvoss.com/euro-cash.eu/pkg/atexit"
. "git.thomasvoss.com/euro-cash.eu/pkg/try"
+ "github.com/jmoiron/sqlx"
"github.com/mattn/go-sqlite3"
)
var (
- db *sql.DB
+ db *sqlx.DB
DBName string
)
func Init(sqlDir fs.FS) {
- db = Try2(sql.Open("sqlite3", DBName))
- Try(db.Ping())
+ db = sqlx.MustConnect("sqlite3", DBName)
atexit.Register(Close)
Try(applyMigrations(sqlDir))
@@ -38,7 +35,7 @@ func Init(sqlDir fs.FS) {
Password: "420",
AdminP: false,
}))
- Try2(GetMintages("ad"))
+ Try2(GetMintages("ad", TypeCirc))
}
func Close() {
@@ -49,7 +46,7 @@ func applyMigrations(dir fs.FS) error {
var latest int
migratedp := true
- rows, err := db.Query("SELECT latest FROM migration")
+ err := db.QueryRow("SELECT latest FROM migration").Scan(&latest)
if err != nil {
e, ok := err.(sqlite3.Error)
/* IDK if there is a better way to do this… lol */
@@ -58,19 +55,9 @@ func applyMigrations(dir fs.FS) error {
} else {
return err
}
- } else {
- defer rows.Close()
}
- if migratedp {
- rows.Next()
- if err := rows.Err(); err != nil {
- return err
- }
- if err := rows.Scan(&latest); err != nil {
- return err
- }
- } else {
+ if !migratedp {
latest = -1
}
@@ -104,24 +91,31 @@ func applyMigrations(dir fs.FS) error {
return err
}
- if _, err := tx.Exec(string(qry)); err != nil {
- tx.Rollback()
- return fmt.Errorf("error in ‘%s’: %w", f, err)
+ var n int
+ if _, err = fmt.Sscanf(f, "%d", &n); err != nil {
+ goto error
}
- var n int
- if _, err := fmt.Sscanf(f, "%d", &n); err != nil {
- return err
+ if _, err = tx.Exec(string(qry)); err != nil {
+ err = fmt.Errorf("error in ‘%s’: %w", f, err)
+ goto error
}
+
_, err = tx.Exec("UPDATE migration SET latest = ? WHERE id = 1", n)
if err != nil {
- return err
+ goto error
}
- if err := tx.Commit(); err != nil {
- return err
+ if err = tx.Commit(); err != nil {
+ goto error
}
+
log.Printf("Applied database migration ‘%s’\n", f)
+ continue
+
+ error:
+ tx.Rollback()
+ return err
}
if last != "" {
@@ -138,96 +132,96 @@ func applyMigrations(dir fs.FS) error {
return nil
}
-func scanToStruct[T any](rs *sql.Rows) (T, error) {
- return scanToStruct2[T](rs, true)
-}
-
-func scanToStructs[T any](rs *sql.Rows) ([]T, error) {
- xs := []T{}
- for rs.Next() {
- x, err := scanToStruct2[T](rs, false)
- if err != nil {
- return nil, err
- }
- xs = append(xs, x)
- }
- return xs, rs.Err()
-}
-
-func scanToStruct2[T any](rs *sql.Rows, callNextP bool) (T, error) {
- var t, zero T
-
- cols, err := rs.Columns()
- if err != nil {
- return zero, err
- }
-
- v := reflect.ValueOf(&t).Elem()
- tType := v.Type()
-
- rawValues := make([]any, len(cols))
- for i := range rawValues {
- var zero any
- rawValues[i] = &zero
- }
-
- if callNextP {
- rs.Next()
- if err := rs.Err(); err != nil {
- return zero, err
- }
- }
- if err := rs.Scan(rawValues...); err != nil {
- return zero, err
- }
-
- /* col idx → [field idx, array idx] */
- arrayTargets := make(map[int][2]int)
- colToField := make(map[string]int)
-
- for i := 0; i < tType.NumField(); i++ {
- field := tType.Field(i)
- tag := field.Tag.Get("db")
- if tag == "" {
- continue
- }
-
- if strings.Contains(tag, ";") {
- dbcols := strings.Split(tag, ";")
- fv := v.Field(i)
- if fv.Kind() != reflect.Array {
- return zero, fmt.Errorf("field ‘%s’ is not array",
- field.Name)
- }
- if len(dbcols) != fv.Len() {
- return zero, fmt.Errorf("field ‘%s’ array length mismatch",
- field.Name)
- }
- for j, colName := range cols {
- for k, dbColName := range dbcols {
- if colName == dbColName {
- arrayTargets[j] = [2]int{i, k}
- }
- }
- }
- } else {
- colToField[tag] = i
- }
- }
-
- for i, col := range cols {
- vp := rawValues[i].(*any)
- if fieldIdx, ok := colToField[col]; ok {
- assignValue(v.Field(fieldIdx), *vp)
- } else if target, ok := arrayTargets[i]; ok {
- assignValue(v.Field(target[0]).Index(target[1]), *vp)
- }
- }
-
- return t, nil
-}
-
-func assignValue(fv reflect.Value, val any) {
+/* func scanToStruct[T any](rs *sql.Rows) (T, error) {
+ return scanToStruct2[T](rs, true)
+ }
+
+ func scanToStructs[T any](rs *sql.Rows) ([]T, error) {
+ xs := []T{}
+ for rs.Next() {
+ x, err := scanToStruct2[T](rs, false)
+ if err != nil {
+ return nil, err
+ }
+ xs = append(xs, x)
+ }
+ return xs, rs.Err()
+ }
+
+ func scanToStruct2[T any](rs *sql.Rows, callNextP bool) (T, error) {
+ var t, zero T
+
+ cols, err := rs.Columns()
+ if err != nil {
+ return zero, err
+ }
+
+ v := reflect.ValueOf(&t).Elem()
+ tType := v.Type()
+
+ rawValues := make([]any, len(cols))
+ for i := range rawValues {
+ var zero any
+ rawValues[i] = &zero
+ }
+
+ if callNextP {
+ rs.Next()
+ if err := rs.Err(); err != nil {
+ return zero, err
+ }
+ }
+ if err := rs.Scan(rawValues...); err != nil {
+ return zero, err
+ }
+
+ /\* col idx → [field idx, array idx] *\/
+ arrayTargets := make(map[int][2]int)
+ colToField := make(map[string]int)
+
+ for i := 0; i < tType.NumField(); i++ {
+ field := tType.Field(i)
+ tag := field.Tag.Get("db")
+ if tag == "" {
+ continue
+ }
+
+ if strings.Contains(tag, ";") {
+ dbcols := strings.Split(tag, ";")
+ fv := v.Field(i)
+ if fv.Kind() != reflect.Array {
+ return zero, fmt.Errorf("field ‘%s’ is not array",
+ field.Name)
+ }
+ if len(dbcols) != fv.Len() {
+ return zero, fmt.Errorf("field ‘%s’ array length mismatch",
+ field.Name)
+ }
+ for j, colName := range cols {
+ for k, dbColName := range dbcols {
+ if colName == dbColName {
+ arrayTargets[j] = [2]int{i, k}
+ }
+ }
+ }
+ } else {
+ colToField[tag] = i
+ }
+ }
+
+ for i, col := range cols {
+ vp := rawValues[i].(*any)
+ if fieldIdx, ok := colToField[col]; ok {
+ assignValue(v.Field(fieldIdx), *vp)
+ } else if target, ok := arrayTargets[i]; ok {
+ assignValue(v.Field(target[0]).Index(target[1]), *vp)
+ }
+ }
+
+ return t, nil
+ } */
+
+/* func assignValue(fv reflect.Value, val any) {
if val == nil {
fv.Set(reflect.Zero(fv.Type()))
return
@@ -236,4 +230,4 @@ func assignValue(fv reflect.Value, val any) {
if v.Type().ConvertibleTo(fv.Type()) {
fv.Set(v.Convert(fv.Type()))
}
-}
+} */