|
| 1 | +package mysql |
| 2 | + |
| 3 | +import ( |
| 4 | + "encoding/json" |
| 5 | + "fmt" |
| 6 | + "strings" |
| 7 | + "time" |
| 8 | + |
| 9 | + "github.com/temporalio/sqlparser" |
| 10 | + "go.temporal.io/server/common/persistence/sql/sqlplugin" |
| 11 | + "go.temporal.io/server/common/persistence/visibility/store/query" |
| 12 | + "go.temporal.io/server/common/searchattribute" |
| 13 | +) |
| 14 | + |
| 15 | +var maxDatetime = time.Date(9999, 12, 31, 23, 59, 59, 0, time.UTC) |
| 16 | + |
| 17 | +type ( |
| 18 | + castExpr struct { |
| 19 | + sqlparser.Expr |
| 20 | + Value sqlparser.Expr |
| 21 | + Type *sqlparser.ConvertType |
| 22 | + } |
| 23 | + |
| 24 | + memberOfExpr struct { |
| 25 | + sqlparser.Expr |
| 26 | + Value sqlparser.Expr |
| 27 | + JSONArr sqlparser.Expr |
| 28 | + } |
| 29 | + |
| 30 | + jsonOverlapsExpr struct { |
| 31 | + sqlparser.Expr |
| 32 | + JSONDoc1 sqlparser.Expr |
| 33 | + JSONDoc2 sqlparser.Expr |
| 34 | + } |
| 35 | +) |
| 36 | + |
| 37 | +var ( |
| 38 | + convertTypeDatetime = &sqlparser.ConvertType{Type: "datetime"} |
| 39 | + convertTypeJSON = &sqlparser.ConvertType{Type: "json"} |
| 40 | +) |
| 41 | + |
| 42 | +var _ sqlparser.Expr = (*castExpr)(nil) |
| 43 | +var _ sqlparser.Expr = (*memberOfExpr)(nil) |
| 44 | +var _ sqlparser.Expr = (*jsonOverlapsExpr)(nil) |
| 45 | + |
| 46 | +func (node *castExpr) Format(buf *sqlparser.TrackedBuffer) { |
| 47 | + buf.Myprintf("cast(%v as %v)", node.Value, node.Type) |
| 48 | +} |
| 49 | + |
| 50 | +func (node *memberOfExpr) Format(buf *sqlparser.TrackedBuffer) { |
| 51 | + buf.Myprintf("%v member of (%v)", node.Value, node.JSONArr) |
| 52 | +} |
| 53 | + |
| 54 | +func (node *jsonOverlapsExpr) Format(buf *sqlparser.TrackedBuffer) { |
| 55 | + buf.Myprintf("json_overlaps(%v, %v)", node.JSONDoc1, node.JSONDoc2) |
| 56 | +} |
| 57 | + |
| 58 | +type queryConverter struct{} |
| 59 | + |
| 60 | +var _ sqlplugin.VisibilityQueryConverter = (*queryConverter)(nil) |
| 61 | + |
| 62 | +func (c *queryConverter) GetDatetimeFormat() string { |
| 63 | + return "2006-01-02 15:04:05.999999" |
| 64 | +} |
| 65 | + |
| 66 | +func (c *queryConverter) GetCoalesceCloseTimeExpr() sqlparser.Expr { |
| 67 | + return query.NewFuncExpr( |
| 68 | + "coalesce", |
| 69 | + query.CloseTimeSaColName, |
| 70 | + &castExpr{ |
| 71 | + Value: query.NewUnsafeSQLString(maxDatetime.Format(c.GetDatetimeFormat())), |
| 72 | + Type: convertTypeDatetime, |
| 73 | + }, |
| 74 | + ) |
| 75 | +} |
| 76 | + |
| 77 | +func (c *queryConverter) ConvertKeywordListComparisonExpr( |
| 78 | + operator string, |
| 79 | + col *query.SAColName, |
| 80 | + value sqlparser.Expr, |
| 81 | +) (sqlparser.Expr, error) { |
| 82 | + var negate bool |
| 83 | + var newExpr sqlparser.Expr |
| 84 | + switch operator { |
| 85 | + case sqlparser.EqualStr, sqlparser.NotEqualStr: |
| 86 | + newExpr = &memberOfExpr{ |
| 87 | + Value: value, |
| 88 | + JSONArr: col, |
| 89 | + } |
| 90 | + negate = operator == sqlparser.NotEqualStr |
| 91 | + case sqlparser.InStr, sqlparser.NotInStr: |
| 92 | + var err error |
| 93 | + newExpr, err = c.buildJSONOverlapsExpr(col, value) |
| 94 | + if err != nil { |
| 95 | + return nil, err |
| 96 | + } |
| 97 | + negate = operator == sqlparser.NotInStr |
| 98 | + default: |
| 99 | + // this should never happen since isSupportedKeywordListOperator should already fail |
| 100 | + return nil, query.NewConverterError( |
| 101 | + "%s: operator '%s' not supported for KeywordList type", |
| 102 | + query.InvalidExpressionErrMessage, |
| 103 | + operator, |
| 104 | + ) |
| 105 | + } |
| 106 | + |
| 107 | + if negate { |
| 108 | + newExpr = &sqlparser.NotExpr{Expr: newExpr} |
| 109 | + } |
| 110 | + return newExpr, nil |
| 111 | +} |
| 112 | + |
| 113 | +func (c *queryConverter) ConvertTextComparisonExpr( |
| 114 | + operator string, |
| 115 | + col *query.SAColName, |
| 116 | + value sqlparser.Expr, |
| 117 | +) (sqlparser.Expr, error) { |
| 118 | + // build the following expression: |
| 119 | + // `match ({col}) against ({value} in natural language mode)` |
| 120 | + var newExpr sqlparser.Expr = &sqlparser.MatchExpr{ |
| 121 | + Columns: []sqlparser.SelectExpr{&sqlparser.AliasedExpr{Expr: col}}, |
| 122 | + Expr: value, |
| 123 | + Option: sqlparser.NaturalLanguageModeStr, |
| 124 | + } |
| 125 | + if operator == sqlparser.NotEqualStr { |
| 126 | + newExpr = &sqlparser.NotExpr{Expr: newExpr} |
| 127 | + } |
| 128 | + return newExpr, nil |
| 129 | +} |
| 130 | + |
| 131 | +func (c *queryConverter) BuildSelectStmt( |
| 132 | + queryParams *query.QueryParams[sqlparser.Expr], |
| 133 | + pageSize int, |
| 134 | + token *sqlplugin.VisibilityPageToken, |
| 135 | +) (string, []any) { |
| 136 | + var whereClauses []string |
| 137 | + var queryArgs []any |
| 138 | + |
| 139 | + queryString := sqlparser.String(queryParams.QueryExpr) |
| 140 | + whereClauses = append(whereClauses, queryString) |
| 141 | + |
| 142 | + if token != nil { |
| 143 | + whereClauses = append( |
| 144 | + whereClauses, |
| 145 | + fmt.Sprintf( |
| 146 | + "((%s = ? AND %s = ? AND %s > ?) OR (%s = ? AND %s < ?) OR %s < ?)", |
| 147 | + sqlparser.String(c.GetCoalesceCloseTimeExpr()), |
| 148 | + searchattribute.GetSqlDbColName(searchattribute.StartTime), |
| 149 | + searchattribute.GetSqlDbColName(searchattribute.RunID), |
| 150 | + sqlparser.String(c.GetCoalesceCloseTimeExpr()), |
| 151 | + searchattribute.GetSqlDbColName(searchattribute.StartTime), |
| 152 | + sqlparser.String(c.GetCoalesceCloseTimeExpr()), |
| 153 | + ), |
| 154 | + ) |
| 155 | + queryArgs = append( |
| 156 | + queryArgs, |
| 157 | + token.CloseTime, |
| 158 | + token.StartTime, |
| 159 | + token.RunID, |
| 160 | + token.CloseTime, |
| 161 | + token.StartTime, |
| 162 | + token.CloseTime, |
| 163 | + ) |
| 164 | + } |
| 165 | + |
| 166 | + dbFields := make([]string, len(sqlplugin.DbFields)) |
| 167 | + for i, field := range sqlplugin.DbFields { |
| 168 | + dbFields[i] = "ev." + field |
| 169 | + } |
| 170 | + |
| 171 | + stmt := fmt.Sprintf( |
| 172 | + `SELECT %s |
| 173 | + FROM executions_visibility ev |
| 174 | + LEFT JOIN custom_search_attributes |
| 175 | + USING (%s, %s) |
| 176 | + WHERE %s |
| 177 | + ORDER BY %s DESC, %s DESC, %s |
| 178 | + LIMIT ?`, |
| 179 | + strings.Join(dbFields, ", "), |
| 180 | + searchattribute.GetSqlDbColName(searchattribute.NamespaceID), |
| 181 | + searchattribute.GetSqlDbColName(searchattribute.RunID), |
| 182 | + strings.Join(whereClauses, " AND "), |
| 183 | + sqlparser.String(c.GetCoalesceCloseTimeExpr()), |
| 184 | + searchattribute.GetSqlDbColName(searchattribute.StartTime), |
| 185 | + searchattribute.GetSqlDbColName(searchattribute.RunID), |
| 186 | + ) |
| 187 | + queryArgs = append(queryArgs, pageSize) |
| 188 | + |
| 189 | + return stmt, queryArgs |
| 190 | +} |
| 191 | + |
| 192 | +func (c *queryConverter) BuildCountStmt( |
| 193 | + queryParams *query.QueryParams[sqlparser.Expr], |
| 194 | +) (string, []any) { |
| 195 | + groupBy := make([]string, 0, len(queryParams.GroupBy)+1) |
| 196 | + for _, field := range queryParams.GroupBy { |
| 197 | + groupBy = append(groupBy, searchattribute.GetSqlDbColName(field.FieldName)) |
| 198 | + } |
| 199 | + |
| 200 | + groupByClause := "" |
| 201 | + if len(queryParams.GroupBy) > 0 { |
| 202 | + groupByClause = fmt.Sprintf("GROUP BY %s", strings.Join(groupBy, ", ")) |
| 203 | + } |
| 204 | + |
| 205 | + return fmt.Sprintf( |
| 206 | + `SELECT %s |
| 207 | + FROM executions_visibility ev |
| 208 | + LEFT JOIN custom_search_attributes |
| 209 | + USING (%s, %s) |
| 210 | + WHERE %s |
| 211 | + %s`, |
| 212 | + strings.Join(append(groupBy, "COUNT(*)"), ", "), |
| 213 | + searchattribute.GetSqlDbColName(searchattribute.NamespaceID), |
| 214 | + searchattribute.GetSqlDbColName(searchattribute.RunID), |
| 215 | + sqlparser.String(queryParams.QueryExpr), |
| 216 | + groupByClause, |
| 217 | + ), nil |
| 218 | +} |
| 219 | + |
| 220 | +func (c *queryConverter) buildJSONOverlapsExpr( |
| 221 | + col *query.SAColName, |
| 222 | + value sqlparser.Expr, |
| 223 | +) (*jsonOverlapsExpr, error) { |
| 224 | + valTuple, isValTuple := value.(sqlparser.ValTuple) |
| 225 | + if !isValTuple { |
| 226 | + return nil, query.NewConverterError( |
| 227 | + "%s: unexpected value type (expected tuple of strings, got %s)", |
| 228 | + query.InvalidExpressionErrMessage, |
| 229 | + sqlparser.String(value), |
| 230 | + ) |
| 231 | + } |
| 232 | + values, err := query.GetUnsafeStringTupleValues(valTuple) |
| 233 | + if err != nil { |
| 234 | + return nil, err |
| 235 | + } |
| 236 | + jsonValue, err := json.Marshal(values) |
| 237 | + if err != nil { |
| 238 | + return nil, err |
| 239 | + } |
| 240 | + return &jsonOverlapsExpr{ |
| 241 | + JSONDoc1: col, |
| 242 | + JSONDoc2: &castExpr{ |
| 243 | + Value: query.NewUnsafeSQLString(string(jsonValue)), |
| 244 | + Type: convertTypeJSON, |
| 245 | + }, |
| 246 | + }, nil |
| 247 | +} |
0 commit comments