mirror of
https://github.com/ergochat/ergo.git
synced 2025-12-21 10:31:59 -08:00
User persistence to sqlite.
This commit is contained in:
parent
48ca57c43d
commit
ccdf7779a5
12 changed files with 172 additions and 115 deletions
|
|
@ -27,7 +27,9 @@ type Queryable interface {
|
|||
QueryRow(string, ...interface{}) *sql.Row
|
||||
}
|
||||
|
||||
type TransactionFunc func(Queryable) bool
|
||||
type Savable interface {
|
||||
Save(q Queryable) bool
|
||||
}
|
||||
|
||||
//
|
||||
// general
|
||||
|
|
@ -36,7 +38,7 @@ type TransactionFunc func(Queryable) bool
|
|||
func NewDatabase() *Database {
|
||||
db, err := sql.Open("sqlite3", "ergonomadic.db")
|
||||
if err != nil {
|
||||
panic("cannot open database")
|
||||
log.Fatalln("cannot open database")
|
||||
}
|
||||
return &Database{db}
|
||||
}
|
||||
|
|
@ -48,7 +50,7 @@ func NewTransaction(tx *sql.Tx) *Transaction {
|
|||
func readLines(filename string) <-chan string {
|
||||
file, err := os.Open(filename)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Fatalln(err)
|
||||
}
|
||||
reader := bufio.NewReader(file)
|
||||
lines := make(chan string)
|
||||
|
|
@ -56,7 +58,7 @@ func readLines(filename string) <-chan string {
|
|||
defer file.Close()
|
||||
defer close(lines)
|
||||
for {
|
||||
line, err := reader.ReadString('\n')
|
||||
line, err := reader.ReadString(';')
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
|
|
@ -70,28 +72,24 @@ func readLines(filename string) <-chan string {
|
|||
return lines
|
||||
}
|
||||
|
||||
func (db *Database) execSqlFile(filename string) {
|
||||
func (db *Database) ExecSqlFile(filename string) *Database {
|
||||
db.Transact(func(q Queryable) bool {
|
||||
for line := range readLines(filepath.Join("sql", filename)) {
|
||||
log.Println(line)
|
||||
q.Exec(line)
|
||||
_, err := q.Exec(line)
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
return db
|
||||
}
|
||||
|
||||
func (db *Database) InitTables() {
|
||||
db.execSqlFile("init.sql")
|
||||
}
|
||||
|
||||
func (db *Database) DropTables() {
|
||||
db.execSqlFile("drop.sql")
|
||||
}
|
||||
|
||||
func (db *Database) Transact(txf TransactionFunc) {
|
||||
func (db *Database) Transact(txf func(Queryable) bool) {
|
||||
tx, err := db.Begin()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
log.Panicln(err)
|
||||
}
|
||||
if txf(tx) {
|
||||
tx.Commit()
|
||||
|
|
@ -100,6 +98,28 @@ func (db *Database) Transact(txf TransactionFunc) {
|
|||
}
|
||||
}
|
||||
|
||||
func (db *Database) Save(s Savable) {
|
||||
db.Transact(func(tx Queryable) bool {
|
||||
return s.Save(tx)
|
||||
})
|
||||
}
|
||||
|
||||
//
|
||||
// general purpose sql
|
||||
//
|
||||
|
||||
func FindId(q Queryable, sql string, args ...interface{}) (rowId RowId, err error) {
|
||||
row := q.QueryRow(sql, args...)
|
||||
err = row.Scan(&rowId)
|
||||
return
|
||||
}
|
||||
|
||||
func Count(q Queryable, sql string, args ...interface{}) (count uint, err error) {
|
||||
row := q.QueryRow(sql, args...)
|
||||
err = row.Scan(&count)
|
||||
return
|
||||
}
|
||||
|
||||
//
|
||||
// data
|
||||
//
|
||||
|
|
@ -117,25 +137,39 @@ type ChannelRow struct {
|
|||
|
||||
// user
|
||||
|
||||
func FindUserByNick(q Queryable, nick string) (ur *UserRow) {
|
||||
ur = new(UserRow)
|
||||
row := q.QueryRow("SELECT * FROM user LIMIT 1 WHERE nick = ?", nick)
|
||||
err := row.Scan(&ur.id, &ur.nick, &ur.hash)
|
||||
func FindAllUsers(q Queryable) (urs []UserRow, err error) {
|
||||
var rows *sql.Rows
|
||||
rows, err = q.Query("SELECT id, nick, hash FROM user")
|
||||
if err != nil {
|
||||
ur = nil
|
||||
return
|
||||
}
|
||||
urs = make([]UserRow, 0)
|
||||
for rows.Next() {
|
||||
ur := UserRow{}
|
||||
err = rows.Scan(&(ur.id), &(ur.nick), &(ur.hash))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
urs = append(urs, ur)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func FindUserIdByNick(q Queryable, nick string) (rowId RowId, err error) {
|
||||
row := q.QueryRow("SELECT id FROM user WHERE nick = ?", nick)
|
||||
err = row.Scan(&rowId)
|
||||
func FindUserByNick(q Queryable, nick string) (ur *UserRow, err error) {
|
||||
ur = &UserRow{}
|
||||
row := q.QueryRow("SELECT id, nick, hash FROM user LIMIT 1 WHERE nick = ?",
|
||||
nick)
|
||||
err = row.Scan(&(ur.id), &(ur.nick), &(ur.hash))
|
||||
return
|
||||
}
|
||||
|
||||
func FindUserIdByNick(q Queryable, nick string) (RowId, error) {
|
||||
return FindId(q, "SELECT id FROM user WHERE nick = ?", nick)
|
||||
}
|
||||
|
||||
func FindChannelByName(q Queryable, name string) (cr *ChannelRow) {
|
||||
cr = new(ChannelRow)
|
||||
row := q.QueryRow("SELECT * FROM channel LIMIT 1 WHERE name = ?", name)
|
||||
row := q.QueryRow("SELECT id, name FROM channel LIMIT 1 WHERE name = ?", name)
|
||||
err := row.Scan(&(cr.id), &(cr.name))
|
||||
if err != nil {
|
||||
cr = nil
|
||||
|
|
@ -185,25 +219,31 @@ func InsertUserChannels(q Queryable, userId RowId, channelIds []RowId) (err erro
|
|||
|
||||
// channel
|
||||
|
||||
func FindChannelIdByName(q Queryable, name string) (channelId RowId, err error) {
|
||||
row := q.QueryRow("SELECT id FROM channel WHERE name = ?", name)
|
||||
err = row.Scan(&channelId)
|
||||
return
|
||||
func FindChannelIdByName(q Queryable, name string) (RowId, error) {
|
||||
return FindId(q, "SELECT id FROM channel WHERE name = ?", name)
|
||||
}
|
||||
|
||||
func FindChannelsForUser(q Queryable, userId RowId) (crs []ChannelRow) {
|
||||
rows, err := q.Query(`SELECT * FROM channel WHERE id IN
|
||||
(SELECT channel_id from user_channel WHERE user_id = ?)`, userId)
|
||||
func FindChannelsForUser(q Queryable, userId RowId) (crs []ChannelRow, err error) {
|
||||
query := ` FROM channel WHERE id IN
|
||||
(SELECT channel_id from user_channel WHERE user_id = ?)`
|
||||
count, err := Count(q, "SELECT COUNT(id)"+query, userId)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
return
|
||||
}
|
||||
crs = make([]ChannelRow, 0)
|
||||
rows, err := q.Query("SELECT id, name"+query, userId)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
crs = make([]ChannelRow, count)
|
||||
var i = 0
|
||||
for rows.Next() {
|
||||
cr := ChannelRow{}
|
||||
if err := rows.Scan(&(cr.id), &(cr.name)); err != nil {
|
||||
panic(err)
|
||||
err = rows.Scan(&(cr.id), &(cr.name))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
crs = append(crs, cr)
|
||||
crs[i] = cr
|
||||
i++
|
||||
}
|
||||
return
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue