/
util.go
163 lines (145 loc) · 3.87 KB
/
util.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
package util
import (
	"fmt"
	"os"
	"reflect"
	"strings"
	"unicode/utf8"

	"github.com/go-pg/pg"
)
// QueryFp is an exported file handle that nothing in this file opens, writes,
// or closes; presumably it is assigned elsewhere (e.g. to record generated
// SQL) — TODO confirm against its users.
var QueryFp *os.File

// Y and Yes are the affirmative answers understood by GetChoice: a typed "y"
// is normalized to Yes, and Yes is returned outright when prompting is skipped.
var (
	Y   = "y"
	Yes = "yes"
)
// historyTag is the suffix appended to a table name to form the name of its
// history table (see GetHistoryTableName).
const historyTag = "_history"
// GetStructField collects the exported fields of model, keyed by their
// reflect.Value.
//
// model may be a struct or a pointer to a struct. Exported fields of embedded
// (anonymous) struct-typed fields are flattened into the result rather than
// being reported themselves. Unexported fields are ignored. Every other
// exported field must carry a `sql` struct tag; a missing tag panics with a
// message naming the field. A nil model, a nil pointer, or a non-struct value
// yields an empty (non-nil) map.
//
// When model is a pointer, every key — including fields promoted from
// embedded structs — refers to model's own memory and is therefore settable.
func GetStructField(model interface{}) (fields map[reflect.Value]reflect.StructField) {
	fields = make(map[reflect.Value]reflect.StructField)
	collectStructFields(reflect.Indirect(reflect.ValueOf(model)), fields)
	return fields
}

// collectStructFields adds the exported fields of struct value v to fields,
// recursing into embedded structs; non-struct or invalid values are ignored.
func collectStructFields(v reflect.Value, fields map[reflect.Value]reflect.StructField) {
	if !v.IsValid() || v.Kind() != reflect.Struct {
		return
	}
	t := v.Type()
	for i := 0; i < v.NumField(); i++ {
		sf := t.Field(i)
		// PkgPath is non-empty exactly for unexported fields (including the
		// blank identifier and unexported embedded types); this is correct for
		// non-ASCII names, unlike comparing the first byte against 'Z'.
		if sf.PkgPath != "" {
			continue
		}
		fv := v.Field(i)
		if sf.Anonymous && fv.Kind() == reflect.Struct {
			// Recurse on the field Value itself (not fv.Interface(), which
			// copies) so promoted fields still alias the caller's struct.
			collectStructFields(fv, fields)
			continue
		}
		if _, ok := sf.Tag.Lookup("sql"); !ok {
			panic(fmt.Sprintf("sql tag not found in field %s", sf.Name))
		}
		fields[fv] = sf
	}
}
// mergeMap copies every entry of b into a, overwriting entries of a that share
// a key. b is left unchanged.
func mergeMap(a, b map[reflect.Value]reflect.StructField) {
	for key := range b {
		a[key] = b[key]
	}
}
// getSQLTag returns the field's `sql` struct tag converted to lower case, or
// "" when the field has no such tag.
func getSQLTag(refField reflect.StructField) (sqlTag string) {
	return strings.ToLower(refField.Tag.Get("sql"))
}
// FieldType extracts the column type declared in the field's `sql` tag, i.e.
// the (lower-cased) text after the first "type:" up to the next space or the
// next "type:", whichever comes first. It returns "" when the tag contains no
// "type:" marker.
func FieldType(refField reflect.StructField) (fType string) {
	const marker = "type:"
	tag := strings.ToLower(refField.Tag.Get("sql"))
	at := strings.Index(tag, marker)
	if at < 0 {
		return ""
	}
	rest := tag[at+len(marker):]
	if next := strings.Index(rest, marker); next >= 0 {
		rest = rest[:next]
	}
	if space := strings.IndexByte(rest, ' '); space >= 0 {
		rest = rest[:space]
	}
	return rest
}
// RefTable extracts the referenced table name from the field's `sql` tag: the
// (lower-cased) text after the first "references", cut at the next "(" or
// "references", with surrounding spaces trimmed. It returns "" when the tag
// contains no "references" marker.
func RefTable(refField reflect.StructField) (refTable string) {
	const marker = "references"
	tag := strings.ToLower(refField.Tag.Get("sql"))
	at := strings.Index(tag, marker)
	if at < 0 {
		return ""
	}
	rest := tag[at+len(marker):]
	if next := strings.Index(rest, marker); next >= 0 {
		rest = rest[:next]
	}
	if paren := strings.IndexByte(rest, '('); paren >= 0 {
		rest = rest[:paren]
	}
	return strings.Trim(rest, " ")
}
// GetChoice prints sql followed by a "Want to continue (y/n):" prompt and
// reads one whitespace-delimited token from standard input.
//
// The answer is lower-cased, and a bare Y ("y") is normalized to Yes. When
// skipPrompt is true nothing is printed or read and Yes is returned directly.
func GetChoice(sql string, skipPrompt bool) (choice string) {
	if skipPrompt {
		return Yes
	}
	fmt.Printf("%v\nWant to continue (y/n):", sql)
	// The scan error is deliberately ignored: if nothing is read, choice is
	// returned unchanged (typically empty).
	_, _ = fmt.Scan(&choice)
	if choice = strings.ToLower(choice); choice == Y {
		return Yes
	}
	return choice
}
// SkipTag reports whether object's struct type opts out of history tracking,
// which is signalled by a `history:"skip"` tag on its first field.
//
// Only the type is inspected, so object may be a struct value, a pointer to a
// struct, or a nil pointer to a struct (e.g. (*Model)(nil)). A nil interface,
// a non-struct value, or a struct without fields yields false.
func SkipTag(object interface{}) (flag bool) {
	t := reflect.TypeOf(object)
	if t == nil {
		return false
	}
	// Dereference one pointer level via the type, not the value: the original
	// Value.Elem() panicked on non-pointers and saw nothing behind nil pointers.
	if t.Kind() == reflect.Ptr {
		t = t.Elem()
	}
	if t.Kind() != reflect.Struct || t.NumField() == 0 {
		return false
	}
	tag, exists := t.Field(0).Tag.Lookup("history")
	return exists && tag == "skip"
}
// GetHistoryTableName returns the name of the history table for tableName,
// formed by appending historyTag.
func GetHistoryTableName(tableName string) string {
	return fmt.Sprintf("%s%s", tableName, historyTag)
}
// GetBeforeInsertTriggerName returns tableName with the suffix "_before_update".
//
// NOTE(review): the function name says "before insert" but the suffix is
// "_before_update". It is left unchanged because renaming would change the
// names of triggers that may already exist in deployed databases — confirm
// the intended event.
func GetBeforeInsertTriggerName(tableName string) string {
	return fmt.Sprintf("%s_before_update", tableName)
}
// GetAfterInsertTriggerName returns the after-insert trigger name for
// tableName: the table name with "_after_insert" appended.
func GetAfterInsertTriggerName(tableName string) string {
	return fmt.Sprintf("%s_after_insert", tableName)
}
// GetAfterUpdateTriggerName returns the after-update trigger name for
// tableName: the table name with "_after_update" appended.
func GetAfterUpdateTriggerName(tableName string) string {
	return fmt.Sprintf("%s_after_update", tableName)
}
// GetAfterDeleteTriggerName returns the after-delete trigger name for
// tableName: the table name with "_after_delete" appended.
func GetAfterDeleteTriggerName(tableName string) string {
	return fmt.Sprintf("%s_after_delete", tableName)
}
// IsAfterUpdateTriggerExists reports whether information_schema.triggers
// lists an AFTER trigger named GetAfterUpdateTriggerName(tName) on table
// tName. The lookup runs inside tx.
//
// exists is true only when the query succeeds and the count is positive; a
// query error is returned in err with exists false.
//
// NOTE(review): the query does not filter by schema, so a same-named table in
// another schema could also match — confirm this is acceptable.
func IsAfterUpdateTriggerExists(tx *pg.Tx, tName string) (exists bool, err error) {
	const query = `
SELECT count(*)
FROM information_schema.triggers
WHERE event_object_table = ?
AND trigger_name = ?
AND action_timing = 'AFTER'`

	var matches int
	_, err = tx.Query(&matches, query, tName, GetAfterUpdateTriggerName(tName))
	if err != nil {
		return false, err
	}
	return matches > 0, nil
}
// GetStrByLen shortens str when it is longer than n bytes.
//
// A string of at most n bytes is returned unchanged. A longer string is cut to
// at most n-1 bytes, moving the cut back further if needed so that a
// multi-byte UTF-8 character is never split (the result stays valid UTF-8 for
// valid input). A non-positive cut yields "" instead of panicking.
//
// NOTE(review): for long inputs the result is n-1 bytes, not n as the name
// suggests. This is preserved deliberately because callers may derive
// database identifiers from it — confirm the intended limit before changing.
func GetStrByLen(str string, n int) string {
	if len(str) <= n {
		return str
	}
	cut := n - 1
	if cut <= 0 {
		// The original str[:n-1] panicked here for n <= 0.
		return ""
	}
	// cut < len(str) because len(str) > n; back off to a rune boundary.
	for cut > 0 && !utf8.RuneStart(str[cut]) {
		cut--
	}
	return str[:cut]
}