I'm implementing a database API with several models, and I need CRUD actions for each model. For now I write an individual `GetAllModels` function and a `Get` method for every model. How can I write these once for all models and just pass around some variables where needed?
This is the pattern I use for each model:
// City is the row model for the cities table; sqlx maps columns onto its
// fields by (lower-cased) field name.
type City struct {
	Attr1, Attr2 string
}
// Country is the row model for the countries table; sqlx maps columns onto
// its fields by (lower-cased) field name.
type Country struct {
	Attr1, Attr2 string
}
// GetAllCities loads every row of the cities table.
//
// On success it returns a non-nil (possibly empty) slice. If the query or the
// scan fails, it returns a nil slice and the error from sqlx instead of
// silently reporting success.
func GetAllCities(db *sqlx.DB) ([]*City, error) {
	// Start from an empty, non-nil slice so callers that encode the result
	// (e.g. as JSON) still see [] rather than null when the table is empty.
	items := []*City{}
	if err := db.Select(&items, "SELECT * FROM cities"); err != nil {
		return nil, err
	}
	return items, nil
}
// Get loads the city whose id column equals id into m.
//
// It returns the error from sqlx's Get unchanged, including the case where no
// row matches. The original query ("SELECT FROM ...") had no select list and
// was invalid SQL; it now selects all columns.
func (m *City) Get(db *sqlx.DB, id string) error {
	return db.Get(m, "SELECT * FROM cities WHERE id = ?", id)
}
// GetAllCountries loads every row of the countries table.
//
// On success it returns a non-nil (possibly empty) slice. If the query or the
// scan fails, it returns a nil slice and the error from sqlx instead of
// silently reporting success.
func GetAllCountries(db *sqlx.DB) ([]*Country, error) {
	// Start from an empty, non-nil slice so callers that encode the result
	// (e.g. as JSON) still see [] rather than null when the table is empty.
	items := []*Country{}
	if err := db.Select(&items, "SELECT * FROM countries"); err != nil {
		return nil, err
	}
	return items, nil
}
// Get loads the country whose id column equals id into m.
//
// It returns the error from sqlx's Get unchanged, including the case where no
// row matches. The original query ("SELECT FROM ...") had no select list and
// was invalid SQL; it now selects all columns.
func (m *Country) Get(db *sqlx.DB, id string) error {
	return db.Get(m, "SELECT * FROM countries WHERE id = ?", id)
}
The only things that change from model to model are the query string and the element type of the slice. How can I write one universal `GetAll` function and one `Get` method for all future models?
I haven't tested this, but it should work. If you receive (or already know) the table and column names for each request, you can fetch rows from the database as `interface{}` values, passing the table and column names as variables.
package main
import (
"errors"
"fmt"
"github.com/jmoiron/sqlx"
)
// selectAll returns a query that fetches every column of every row in table.
// The name is concatenated into the SQL text as-is, so it must already have
// been validated by the caller.
func selectAll(table string) string {
	// fmt.Sprint inserts no separators between string operands, so this is
	// plain concatenation.
	return fmt.Sprint("SELECT * FROM ", table)
}
// selectWhere returns a query that fetches every column of the rows in table
// whose column equals a single bound "?" parameter. Both identifiers are
// spliced into the SQL text as-is, so they must already have been validated.
func selectWhere(table string, column string) string {
	// fmt.Sprint inserts no separators between string operands, so this is
	// plain concatenation of the pieces.
	query := fmt.Sprint("SELECT * FROM ", table, " WHERE ", column, " = ?")
	return query
}
// validTableStr reports whether table is safe to splice into a query as a bare
// SQL identifier.
//
// selectAll and selectWhere interpolate the name directly into the SQL text
// (identifiers cannot be bound as "?" parameters), so an unchecked name taken
// from a request is an SQL-injection vector. This check accepts only non-empty
// names made of ASCII letters, digits and underscores that do not start with a
// digit. Spaces, quotes, semicolons, dots and comment markers are all rejected.
// Where possible, tighten this further to an allowlist of the tables that
// actually exist.
func validTableStr(table string) bool {
	if table == "" {
		return false
	}
	for i, r := range table {
		switch {
		case r == '_', r >= 'a' && r <= 'z', r >= 'A' && r <= 'Z':
			// always allowed
		case r >= '0' && r <= '9' && i > 0:
			// digits allowed anywhere but the first position
		default:
			return false
		}
	}
	return true
}
// validColumnStr reports whether column is safe to splice into a query as a
// bare SQL identifier.
//
// selectWhere interpolates the column name directly into the SQL text, so an
// unchecked name taken from a request is an SQL-injection vector. This check
// accepts only non-empty names made of ASCII letters, digits and underscores
// that do not start with a digit.
//
// table is currently unused. It is kept so the check can later verify that
// column really belongs to table (e.g. against the schema).
func validColumnStr(table string, column string) bool {
	if column == "" {
		return false
	}
	for i, r := range column {
		switch {
		case r == '_', r >= 'a' && r <= 'z', r >= 'A' && r <= 'Z':
			// always allowed
		case r >= '0' && r <= '9' && i > 0:
			// digits allowed anywhere but the first position
		default:
			return false
		}
	}
	return true
}
// GetAll returns every row of table.
//
// Because the row type is not known at compile time, each element of the
// result is a map[string]interface{} keyed by column name. The concrete value
// types inside the map are whatever the database driver produces, which is
// driver-dependent (some drivers return []byte for text columns).
//
// Scanning straight into []interface{} with db.Select does not work: sqlx
// treats interface{} as a single-column destination and fails for any
// multi-column SELECT *. That is why the rows are scanned with MapScan here.
//
// It returns an error if table fails validTableStr, or if the query, a row
// scan, or row iteration fails. The slice is nil whenever err is non-nil.
func GetAll(db *sqlx.DB, table string) ([]interface{}, error) {
	if !validTableStr(table) {
		return nil, errors.New("invalid table name")
	}
	rows, err := db.Queryx(selectAll(table))
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	items := []interface{}{}
	for rows.Next() {
		row := map[string]interface{}{}
		if err := rows.MapScan(row); err != nil {
			return nil, err
		}
		items = append(items, row)
	}
	// Next returns false on both end-of-rows and failure; Err distinguishes them.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}
// GetWhere returns the rows of table whose column equals value.
//
// table and column are validated and then spliced into the SQL text. value is
// passed as a bound "?" parameter, so it is never interpolated. Each element of
// the result is a map[string]interface{} keyed by column name, holding
// driver-dependent value types.
//
// Scanning straight into []interface{} with db.Select does not work: sqlx
// treats interface{} as a single-column destination and fails for any
// multi-column SELECT *. That is why the rows are scanned with MapScan here.
//
// It returns an error if table or column fails validation, or if the query, a
// row scan, or row iteration fails. The slice is nil whenever err is non-nil.
func GetWhere(db *sqlx.DB, table string, column string, value interface{}) ([]interface{}, error) {
	if !validTableStr(table) {
		return nil, errors.New("invalid table name")
	}
	if !validColumnStr(table, column) {
		return nil, errors.New("invalid column name")
	}
	rows, err := db.Queryx(selectWhere(table, column), value)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	items := []interface{}{}
	for rows.Next() {
		row := map[string]interface{}{}
		if err := rows.MapScan(row); err != nil {
			return nil, err
		}
		items = append(items, row)
	}
	// Next returns false on both end-of-rows and failure; Err distinguishes them.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return items, nil
}