gitea/vendor/xorm.io/xorm/session_insert.go
zeripath 1790f01dd9
Upgrade xorm to v1.2.2 (#16663) & Add test to ensure that dumping of login sources remains correct (#16847) (#16848)
* Upgrade xorm to v1.2.2 (#16663)

Backport #16663

Fix #16683

* Add test to ensure that dumping of login sources remains correct (#16847)

#16831 has occurred because of a missed regression. This PR adds a simple test to
try to prevent this occuring again.

Signed-off-by: Andrew Thornton <art27@cantab.net>

Co-authored-by: Lunny Xiao <xiaolunwen@gmail.com>
2021-08-28 13:16:19 +02:00

672 lines
17 KiB
Go
Vendored

// Copyright 2016 The Xorm Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package xorm
import (
"errors"
"fmt"
"reflect"
"sort"
"strings"
"time"
"xorm.io/xorm/convert"
"xorm.io/xorm/dialects"
"xorm.io/xorm/internal/utils"
"xorm.io/xorm/schemas"
)
// ErrNoElementsOnSlice represents an error there is no element when insert
var ErrNoElementsOnSlice = errors.New("no element on slice when insert")
// Insert insert one or more beans
func (session *Session) Insert(beans ...interface{}) (int64, error) {
var affected int64
var err error
if session.isAutoClose {
defer session.Close()
}
session.autoResetStatement = false
defer func() {
session.autoResetStatement = true
session.resetStatement()
}()
for _, bean := range beans {
var cnt int64
var err error
switch v := bean.(type) {
case map[string]interface{}:
cnt, err = session.insertMapInterface(v)
case []map[string]interface{}:
cnt, err = session.insertMultipleMapInterface(v)
case map[string]string:
cnt, err = session.insertMapString(v)
case []map[string]string:
cnt, err = session.insertMultipleMapString(v)
default:
sliceValue := reflect.Indirect(reflect.ValueOf(bean))
if sliceValue.Kind() == reflect.Slice {
cnt, err = session.insertMultipleStruct(bean)
} else {
cnt, err = session.insertStruct(bean)
}
}
if err != nil {
return affected, err
}
affected += cnt
}
return affected, err
}
func (session *Session) insertMultipleStruct(rowsSlicePtr interface{}) (int64, error) {
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
if sliceValue.Kind() != reflect.Slice {
return 0, errors.New("needs a pointer to a slice")
}
if sliceValue.Len() <= 0 {
return 0, ErrNoElementsOnSlice
}
if err := session.statement.SetRefBean(sliceValue.Index(0).Interface()); err != nil {
return 0, err
}
tableName := session.statement.TableName()
if len(tableName) <= 0 {
return 0, ErrTableNotFound
}
var (
table = session.statement.RefTable
size = sliceValue.Len()
colNames []string
colMultiPlaces []string
args []interface{}
cols []*schemas.Column
)
for i := 0; i < size; i++ {
v := sliceValue.Index(i)
var vv reflect.Value
switch v.Kind() {
case reflect.Interface:
vv = reflect.Indirect(v.Elem())
default:
vv = reflect.Indirect(v)
}
elemValue := v.Interface()
var colPlaces []string
// handle BeforeInsertProcessor
// !nashtsai! does user expect it's same slice to passed closure when using Before()/After() when insert multi??
for _, closure := range session.beforeClosures {
closure(elemValue)
}
if processor, ok := interface{}(elemValue).(BeforeInsertProcessor); ok {
processor.BeforeInsert()
}
// --
for _, col := range table.Columns() {
ptrFieldValue, err := col.ValueOfV(&vv)
if err != nil {
return 0, err
}
fieldValue := *ptrFieldValue
if col.IsAutoIncrement && utils.IsZero(fieldValue.Interface()) {
continue
}
if col.MapType == schemas.ONLYFROMDB {
continue
}
if col.IsDeleted {
continue
}
if session.statement.OmitColumnMap.Contain(col.Name) {
continue
}
if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(col.Name) {
continue
}
if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime {
val, t, err := session.engine.nowTime(col)
if err != nil {
return 0, err
}
args = append(args, val)
var colName = col.Name
session.afterClosures = append(session.afterClosures, func(bean interface{}) {
col := table.GetColumn(colName)
setColumnTime(bean, col, t)
})
} else if col.IsVersion && session.statement.CheckVersion {
args = append(args, 1)
var colName = col.Name
session.afterClosures = append(session.afterClosures, func(bean interface{}) {
col := table.GetColumn(colName)
setColumnInt(bean, col, 1)
})
} else {
arg, err := session.statement.Value2Interface(col, fieldValue)
if err != nil {
return 0, err
}
args = append(args, arg)
}
if i == 0 {
colNames = append(colNames, col.Name)
cols = append(cols, col)
}
colPlaces = append(colPlaces, "?")
}
colMultiPlaces = append(colMultiPlaces, strings.Join(colPlaces, ", "))
}
cleanupProcessorsClosures(&session.beforeClosures)
quoter := session.engine.dialect.Quoter()
var sql string
colStr := quoter.Join(colNames, ",")
if session.engine.dialect.URI().DBType == schemas.ORACLE {
temp := fmt.Sprintf(") INTO %s (%v) VALUES (",
quoter.Quote(tableName),
colStr)
sql = fmt.Sprintf("INSERT ALL INTO %s (%v) VALUES (%v) SELECT 1 FROM DUAL",
quoter.Quote(tableName),
colStr,
strings.Join(colMultiPlaces, temp))
} else {
sql = fmt.Sprintf("INSERT INTO %s (%v) VALUES (%v)",
quoter.Quote(tableName),
colStr,
strings.Join(colMultiPlaces, "),("))
}
res, err := session.exec(sql, args...)
if err != nil {
return 0, err
}
session.cacheInsert(tableName)
lenAfterClosures := len(session.afterClosures)
for i := 0; i < size; i++ {
elemValue := reflect.Indirect(sliceValue.Index(i)).Addr().Interface()
// handle AfterInsertProcessor
if session.isAutoCommit {
// !nashtsai! does user expect it's same slice to passed closure when using Before()/After() when insert multi??
for _, closure := range session.afterClosures {
closure(elemValue)
}
if processor, ok := elemValue.(AfterInsertProcessor); ok {
processor.AfterInsert()
}
} else {
if lenAfterClosures > 0 {
if value, has := session.afterInsertBeans[elemValue]; has && value != nil {
*value = append(*value, session.afterClosures...)
} else {
afterClosures := make([]func(interface{}), lenAfterClosures)
copy(afterClosures, session.afterClosures)
session.afterInsertBeans[elemValue] = &afterClosures
}
} else {
if _, ok := elemValue.(AfterInsertProcessor); ok {
session.afterInsertBeans[elemValue] = nil
}
}
}
}
cleanupProcessorsClosures(&session.afterClosures)
return res.RowsAffected()
}
// InsertMulti insert multiple records
func (session *Session) InsertMulti(rowsSlicePtr interface{}) (int64, error) {
if session.isAutoClose {
defer session.Close()
}
sliceValue := reflect.Indirect(reflect.ValueOf(rowsSlicePtr))
if sliceValue.Kind() != reflect.Slice {
return 0, ErrPtrSliceType
}
return session.insertMultipleStruct(rowsSlicePtr)
}
func (session *Session) insertStruct(bean interface{}) (int64, error) {
if err := session.statement.SetRefBean(bean); err != nil {
return 0, err
}
if len(session.statement.TableName()) <= 0 {
return 0, ErrTableNotFound
}
// handle BeforeInsertProcessor
for _, closure := range session.beforeClosures {
closure(bean)
}
cleanupProcessorsClosures(&session.beforeClosures) // cleanup after used
if processor, ok := interface{}(bean).(BeforeInsertProcessor); ok {
processor.BeforeInsert()
}
var tableName = session.statement.TableName()
table := session.statement.RefTable
colNames, args, err := session.genInsertColumns(bean)
if err != nil {
return 0, err
}
sqlStr, args, err := session.statement.GenInsertSQL(colNames, args)
if err != nil {
return 0, err
}
handleAfterInsertProcessorFunc := func(bean interface{}) {
if session.isAutoCommit {
for _, closure := range session.afterClosures {
closure(bean)
}
if processor, ok := interface{}(bean).(AfterInsertProcessor); ok {
processor.AfterInsert()
}
} else {
lenAfterClosures := len(session.afterClosures)
if lenAfterClosures > 0 {
if value, has := session.afterInsertBeans[bean]; has && value != nil {
*value = append(*value, session.afterClosures...)
} else {
afterClosures := make([]func(interface{}), lenAfterClosures)
copy(afterClosures, session.afterClosures)
session.afterInsertBeans[bean] = &afterClosures
}
} else {
if _, ok := interface{}(bean).(AfterInsertProcessor); ok {
session.afterInsertBeans[bean] = nil
}
}
}
cleanupProcessorsClosures(&session.afterClosures) // cleanup after used
}
// if there is auto increment column and driver don't support return it
if len(table.AutoIncrement) > 0 && !session.engine.driver.Features().SupportReturnInsertedID {
var sql = sqlStr
if session.engine.dialect.URI().DBType == schemas.ORACLE {
sql = "select seq_atable.currval from dual"
}
rows, err := session.queryRows(sql, args...)
if err != nil {
return 0, err
}
defer rows.Close()
defer handleAfterInsertProcessorFunc(bean)
session.cacheInsert(tableName)
if table.Version != "" && session.statement.CheckVersion {
verValue, err := table.VersionColumn().ValueOf(bean)
if err != nil {
session.engine.logger.Errorf("%v", err)
} else if verValue.IsValid() && verValue.CanSet() {
session.incrVersionFieldValue(verValue)
}
}
var id int64
if !rows.Next() {
if rows.Err() != nil {
return 0, rows.Err()
}
return 0, errors.New("insert successfully but not returned id")
}
if err := rows.Scan(&id); err != nil {
return 1, err
}
aiValue, err := table.AutoIncrColumn().ValueOf(bean)
if err != nil {
session.engine.logger.Errorf("%v", err)
}
if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
return 1, nil
}
return 1, convert.AssignValue(*aiValue, id)
}
res, err := session.exec(sqlStr, args...)
if err != nil {
return 0, err
}
defer handleAfterInsertProcessorFunc(bean)
session.cacheInsert(tableName)
if table.Version != "" && session.statement.CheckVersion {
verValue, err := table.VersionColumn().ValueOf(bean)
if err != nil {
session.engine.logger.Errorf("%v", err)
} else if verValue.IsValid() && verValue.CanSet() {
session.incrVersionFieldValue(verValue)
}
}
if table.AutoIncrement == "" {
return res.RowsAffected()
}
var id int64
id, err = res.LastInsertId()
if err != nil || id <= 0 {
return res.RowsAffected()
}
aiValue, err := table.AutoIncrColumn().ValueOf(bean)
if err != nil {
session.engine.logger.Errorf("%v", err)
}
if aiValue == nil || !aiValue.IsValid() || !aiValue.CanSet() {
return res.RowsAffected()
}
if err := convert.AssignValue(*aiValue, id); err != nil {
return 0, err
}
return res.RowsAffected()
}
// InsertOne insert only one struct into database as a record.
// The in parameter bean must a struct or a point to struct. The return
// parameter is inserted and error
func (session *Session) InsertOne(bean interface{}) (int64, error) {
if session.isAutoClose {
defer session.Close()
}
return session.insertStruct(bean)
}
func (session *Session) cacheInsert(table string) error {
if !session.statement.UseCache {
return nil
}
cacher := session.engine.cacherMgr.GetCacher(table)
if cacher == nil {
return nil
}
session.engine.logger.Debugf("[cache] clear SQL: %v", table)
cacher.ClearIds(table)
return nil
}
// genInsertColumns generates insert needed columns
func (session *Session) genInsertColumns(bean interface{}) ([]string, []interface{}, error) {
table := session.statement.RefTable
colNames := make([]string, 0, len(table.ColumnsSeq()))
args := make([]interface{}, 0, len(table.ColumnsSeq()))
for _, col := range table.Columns() {
if col.MapType == schemas.ONLYFROMDB {
continue
}
if session.statement.OmitColumnMap.Contain(col.Name) {
continue
}
if len(session.statement.ColumnMap) > 0 && !session.statement.ColumnMap.Contain(col.Name) {
continue
}
if session.statement.IncrColumns.IsColExist(col.Name) {
continue
} else if session.statement.DecrColumns.IsColExist(col.Name) {
continue
} else if session.statement.ExprColumns.IsColExist(col.Name) {
continue
}
if col.IsDeleted {
arg, err := dialects.FormatColumnTime(session.engine.dialect, session.engine.DatabaseTZ, col, time.Time{})
if err != nil {
return nil, nil, err
}
args = append(args, arg)
colNames = append(colNames, col.Name)
continue
}
fieldValuePtr, err := col.ValueOf(bean)
if err != nil {
return nil, nil, err
}
fieldValue := *fieldValuePtr
if col.IsAutoIncrement && utils.IsValueZero(fieldValue) {
continue
}
// !evalphobia! set fieldValue as nil when column is nullable and zero-value
if _, ok := getFlagForColumn(session.statement.NullableMap, col); ok {
if col.Nullable && utils.IsValueZero(fieldValue) {
var nilValue *int
fieldValue = reflect.ValueOf(nilValue)
}
}
if (col.IsCreated || col.IsUpdated) && session.statement.UseAutoTime /*&& isZero(fieldValue.Interface())*/ {
// if time is non-empty, then set to auto time
val, t, err := session.engine.nowTime(col)
if err != nil {
return nil, nil, err
}
args = append(args, val)
var colName = col.Name
session.afterClosures = append(session.afterClosures, func(bean interface{}) {
col := table.GetColumn(colName)
setColumnTime(bean, col, t)
})
} else if col.IsVersion && session.statement.CheckVersion {
args = append(args, 1)
} else {
arg, err := session.statement.Value2Interface(col, fieldValue)
if err != nil {
return colNames, args, err
}
args = append(args, arg)
}
colNames = append(colNames, col.Name)
}
return colNames, args, nil
}
func (session *Session) insertMapInterface(m map[string]interface{}) (int64, error) {
if len(m) == 0 {
return 0, ErrParamsType
}
tableName := session.statement.TableName()
if len(tableName) <= 0 {
return 0, ErrTableNotFound
}
var columns = make([]string, 0, len(m))
exprs := session.statement.ExprColumns
for k := range m {
if !exprs.IsColExist(k) {
columns = append(columns, k)
}
}
sort.Strings(columns)
var args = make([]interface{}, 0, len(m))
for _, colName := range columns {
args = append(args, m[colName])
}
return session.insertMap(columns, args)
}
func (session *Session) insertMultipleMapInterface(maps []map[string]interface{}) (int64, error) {
if len(maps) <= 0 {
return 0, ErrNoElementsOnSlice
}
tableName := session.statement.TableName()
if len(tableName) <= 0 {
return 0, ErrTableNotFound
}
var columns = make([]string, 0, len(maps[0]))
exprs := session.statement.ExprColumns
for k := range maps[0] {
if !exprs.IsColExist(k) {
columns = append(columns, k)
}
}
sort.Strings(columns)
var argss = make([][]interface{}, 0, len(maps))
for _, m := range maps {
var args = make([]interface{}, 0, len(m))
for _, colName := range columns {
args = append(args, m[colName])
}
argss = append(argss, args)
}
return session.insertMultipleMap(columns, argss)
}
func (session *Session) insertMapString(m map[string]string) (int64, error) {
if len(m) == 0 {
return 0, ErrParamsType
}
tableName := session.statement.TableName()
if len(tableName) <= 0 {
return 0, ErrTableNotFound
}
var columns = make([]string, 0, len(m))
exprs := session.statement.ExprColumns
for k := range m {
if !exprs.IsColExist(k) {
columns = append(columns, k)
}
}
sort.Strings(columns)
var args = make([]interface{}, 0, len(m))
for _, colName := range columns {
args = append(args, m[colName])
}
return session.insertMap(columns, args)
}
func (session *Session) insertMultipleMapString(maps []map[string]string) (int64, error) {
if len(maps) <= 0 {
return 0, ErrNoElementsOnSlice
}
tableName := session.statement.TableName()
if len(tableName) <= 0 {
return 0, ErrTableNotFound
}
var columns = make([]string, 0, len(maps[0]))
exprs := session.statement.ExprColumns
for k := range maps[0] {
if !exprs.IsColExist(k) {
columns = append(columns, k)
}
}
sort.Strings(columns)
var argss = make([][]interface{}, 0, len(maps))
for _, m := range maps {
var args = make([]interface{}, 0, len(m))
for _, colName := range columns {
args = append(args, m[colName])
}
argss = append(argss, args)
}
return session.insertMultipleMap(columns, argss)
}
func (session *Session) insertMap(columns []string, args []interface{}) (int64, error) {
tableName := session.statement.TableName()
if len(tableName) <= 0 {
return 0, ErrTableNotFound
}
sql, args, err := session.statement.GenInsertMapSQL(columns, args)
if err != nil {
return 0, err
}
if err := session.cacheInsert(tableName); err != nil {
return 0, err
}
res, err := session.exec(sql, args...)
if err != nil {
return 0, err
}
affected, err := res.RowsAffected()
if err != nil {
return 0, err
}
return affected, nil
}
func (session *Session) insertMultipleMap(columns []string, argss [][]interface{}) (int64, error) {
tableName := session.statement.TableName()
if len(tableName) <= 0 {
return 0, ErrTableNotFound
}
sql, args, err := session.statement.GenInsertMultipleMapSQL(columns, argss)
if err != nil {
return 0, err
}
if err := session.cacheInsert(tableName); err != nil {
return 0, err
}
res, err := session.exec(sql, args...)
if err != nil {
return 0, err
}
affected, err := res.RowsAffected()
if err != nil {
return 0, err
}
return affected, nil
}