1
2
3
4
5 package walk
6
7 import (
8 "cmp"
9 "fmt"
10 "go/constant"
11 "go/token"
12 "math"
13 "math/bits"
14 "slices"
15 "sort"
16 "strings"
17
18 "cmd/compile/internal/base"
19 "cmd/compile/internal/ir"
20 "cmd/compile/internal/objw"
21 "cmd/compile/internal/reflectdata"
22 "cmd/compile/internal/rttype"
23 "cmd/compile/internal/ssagen"
24 "cmd/compile/internal/typecheck"
25 "cmd/compile/internal/types"
26 "cmd/internal/obj"
27 "cmd/internal/src"
28 )
29
30
31 func walkSwitch(sw *ir.SwitchStmt) {
32
33 if sw.Walked() {
34 return
35 }
36 sw.SetWalked(true)
37
38 if sw.Tag != nil && sw.Tag.Op() == ir.OTYPESW {
39 walkSwitchType(sw)
40 } else {
41 walkSwitchExpr(sw)
42 }
43 }
44
45
46
// walkSwitchExpr lowers an expression switch statement into sw.Compiled.
//
// The tag is evaluated once (copied into a temporary unless it is a literal
// or nil). tryLookupTable may then emit a table-driven early return and
// remove the cases it covers. The remaining cases are turned into dispatch
// code by exprSwitch, which jumps to one label per case. That dispatch code
// is followed by the default jump (or a break when there is no default) and
// then the labeled case bodies.
func walkSwitchExpr(sw *ir.SwitchStmt) {
	lno := ir.SetPos(sw)

	cond := sw.Tag
	sw.Tag = nil

	// A tagless "switch { ... }" is compiled as "switch true { ... }".
	if cond == nil {
		cond = ir.NewBool(base.Pos, true)
		cond = typecheck.Expr(cond)
		cond = typecheck.DefaultLit(cond, nil)
	}

	// For "switch string(byteslice)" where every case expression is a
	// literal, switch the conversion to OBYTES2STRTMP (a temporary
	// conversion, presumably avoiding the copy). This must happen before
	// walkExpr, which lowers the conversion.
	if cond.Op() == ir.OBYTES2STR && allCaseExprsAreSideEffectFree(sw) {
		cond := cond.(*ir.ConvExpr)
		cond.SetOp(ir.OBYTES2STRTMP)
	}

	cond = walkExpr(cond, sw.PtrInit())
	// Evaluate the tag exactly once; literals and nil may be reused freely.
	if cond.Op() != ir.OLITERAL && cond.Op() != ir.ONIL {
		cond = copyExpr(cond, cond.Type(), &sw.Compiled)
	}

	base.Pos = lno

	// May append a table lookup to sw.Compiled and shrink sw.Cases, so it
	// must run before the case loop below.
	tryLookupTable(sw, cond)

	s := exprSwitch{
		pos:      lno,
		exprname: cond,
	}

	var defaultGoto ir.Node
	var body ir.Nodes
	for _, ncase := range sw.Cases {
		label := typecheck.AutoLabel(".s")
		jmp := ir.NewBranchStmt(ncase.Pos(), ir.OGOTO, label)

		// Dispatch: a case with no expressions is the default case.
		if len(ncase.List) == 0 {
			if defaultGoto != nil {
				base.Fatalf("duplicate default case not detected during typechecking")
			}
			defaultGoto = jmp
		}

		for i, n1 := range ncase.List {
			var rtype ir.Node
			if i < len(ncase.RTypes) {
				rtype = ncase.RTypes[i]
			}
			s.Add(ncase.Pos(), n1, rtype, jmp)
		}

		// Body: label, statements, then a break unless it falls through.
		body.Append(ir.NewLabelStmt(ncase.Pos(), label))
		body.Append(ncase.Body...)
		if fall, pos := endsInFallthrough(ncase.Body); !fall {
			br := ir.NewBranchStmt(base.Pos, ir.OBREAK, nil)
			br.SetPos(pos)
			body.Append(br)
		}
	}
	sw.Cases = nil

	// No default case: leave the switch when nothing matched.
	if defaultGoto == nil {
		br := ir.NewBranchStmt(base.Pos, ir.OBREAK, nil)
		br.SetPos(br.Pos().WithNotStmt())
		defaultGoto = br
	}

	s.Emit(&sw.Compiled)
	sw.Compiled.Append(defaultGoto)
	sw.Compiled.Append(body.Take()...)
	walkStmtList(sw.Compiled)
}
130
131
132 type exprSwitch struct {
133 pos src.XPos
134 exprname ir.Node
135
136 done ir.Nodes
137 clauses []exprClause
138 }
139
140 type exprClause struct {
141 pos src.XPos
142 lo, hi ir.Node
143 rtype ir.Node
144 jmp ir.Node
145 }
146
147 func (s *exprSwitch) Add(pos src.XPos, expr, rtype, jmp ir.Node) {
148 c := exprClause{pos: pos, lo: expr, hi: expr, rtype: rtype, jmp: jmp}
149 if types.IsOrdered[s.exprname.Type().Kind()] && expr.Op() == ir.OLITERAL {
150 s.clauses = append(s.clauses, c)
151 return
152 }
153
154 s.flush()
155 s.clauses = append(s.clauses, c)
156 s.flush()
157 }
158
159 func (s *exprSwitch) Emit(out *ir.Nodes) {
160 s.flush()
161 out.Append(s.done.Take()...)
162 }
163
// flush generates dispatch code for the buffered clauses into s.done and
// clears the buffer.
//
// Strings are dispatched in two levels:
//   - the clauses are sorted by (length, value) and grouped into runs of
//     equal length;
//   - an outer integer switch on len(exprname) selects a run;
//   - stringSearch then picks the matching string within that run.
//
// Other ordered types are sorted by value. Consecutive integer cases that
// jump to the same target are merged into ranges, and the result is handed
// to search.
func (s *exprSwitch) flush() {
	cc := s.clauses
	s.clauses = nil
	if len(cc) == 0 {
		return
	}

	// Add flushes a non-literal clause on its own, so a single buffered
	// clause may be a non-literal. Only when len(cc) >= 2 are all clauses
	// guaranteed to be literals, which the string path below relies on
	// (ir.StringVal).
	if s.exprname.Type().IsString() && len(cc) >= 2 {
		// Sort by length, then by value. Comparing lengths is much cheaper
		// than comparing contents; the code below needs only a consistent
		// order in which equal-length strings are adjacent.
		slices.SortFunc(cc, func(a, b exprClause) int {
			si := ir.StringVal(a.lo)
			sj := ir.StringVal(b.lo)
			if len(si) != len(sj) {
				return cmp.Compare(len(si), len(sj))
			}
			return strings.Compare(si, sj)
		})

		// runLen returns the string length shared by every clause in run.
		runLen := func(run []exprClause) int64 { return int64(len(ir.StringVal(run[0].lo))) }

		// Split cc into maximal runs of equal-length strings.
		var runs [][]exprClause
		start := 0
		for i := 1; i < len(cc); i++ {
			if runLen(cc[start:]) != runLen(cc[i:]) {
				runs = append(runs, cc[start:i])
				start = i
			}
		}
		runs = append(runs, cc[start:])

		// Generated layout:
		//
		//	goto outerLabel
		//	label_1: <stringSearch in run 1>; goto endLabel
		//	...
		//	label_k: <stringSearch in run k>; goto endLabel
		//	outerLabel: <switch on len(exprname) -> label_i>
		//	endLabel:
		//
		// A string that matches nothing reaches endLabel, either by falling
		// out of the outer switch or through its run's "goto endLabel".
		outerLabel := typecheck.AutoLabel(".s")
		endLabel := typecheck.AutoLabel(".s")

		// Jump over the per-run searches to the length dispatch.
		s.done.Append(ir.NewBranchStmt(s.pos, ir.OGOTO, outerLabel))

		var outer exprSwitch
		outer.exprname = ir.NewUnaryExpr(s.pos, ir.OLEN, s.exprname)
		outer.exprname.SetType(types.Types[types.TINT])

		for _, run := range runs {
			// Emit the string search for this run behind its own label.
			label := typecheck.AutoLabel(".s")

			pos := run[0].pos
			s.done.Append(ir.NewLabelStmt(pos, label))
			stringSearch(s.exprname, run, &s.done)
			s.done.Append(ir.NewBranchStmt(pos, ir.OGOTO, endLabel))

			// Register "len(exprname) == runLen(run)" in the outer switch.
			cas := ir.NewInt(pos, runLen(run))
			jmp := ir.NewBranchStmt(pos, ir.OGOTO, label)
			outer.Add(pos, cas, nil, jmp)
		}
		s.done.Append(ir.NewLabelStmt(s.pos, outerLabel))
		outer.Emit(&s.done)
		s.done.Append(ir.NewLabelStmt(s.pos, endLabel))
		return
	}

	// Non-string ordered types: sort by constant value.
	sort.Slice(cc, func(i, j int) bool {
		return constant.Compare(cc[i].lo.Val(), token.LSS, cc[j].lo.Val())
	})

	// Merge consecutive integer cases with the same jump target into a
	// single [lo, hi] range clause.
	if s.exprname.Type().IsInteger() {
		consecutive := func(last, next constant.Value) bool {
			delta := constant.BinaryOp(next, token.SUB, last)
			return constant.Compare(delta, token.EQL, constant.MakeInt64(1))
		}

		merged := cc[:1]
		for _, c := range cc[1:] {
			last := &merged[len(merged)-1]
			if last.jmp == c.jmp && consecutive(last.hi.Val(), c.lo.Val()) {
				last.hi = c.lo
			} else {
				merged = append(merged, c)
			}
		}
		cc = merged
	}

	s.search(cc, &s.done)
}
282
283 func (s *exprSwitch) search(cc []exprClause, out *ir.Nodes) {
284 if s.tryJumpTable(cc, out) {
285 return
286 }
287 binarySearch(len(cc), out,
288 func(i int) ir.Node {
289 return ir.NewBinaryExpr(base.Pos, ir.OLE, s.exprname, cc[i-1].hi)
290 },
291 func(i int, nif *ir.IfStmt) {
292 c := &cc[i]
293 nif.Cond = c.test(s.exprname)
294 nif.Body = []ir.Node{c.jmp}
295 },
296 )
297 }
298
299
300 func (s *exprSwitch) tryJumpTable(cc []exprClause, out *ir.Nodes) bool {
301 const minCases = 8
302 const minDensity = 4
303
304 if base.Flag.N != 0 || !ssagen.Arch.LinkArch.CanJumpTable || base.Ctxt.Retpoline {
305 return false
306 }
307 if len(cc) < minCases {
308 return false
309 }
310 if cc[0].lo.Val().Kind() != constant.Int {
311 return false
312 }
313 if s.exprname.Type().Size() > int64(types.PtrSize) {
314 return false
315 }
316 min := cc[0].lo.Val()
317 max := cc[len(cc)-1].hi.Val()
318 width := constant.BinaryOp(constant.BinaryOp(max, token.SUB, min), token.ADD, constant.MakeInt64(1))
319 limit := constant.MakeInt64(int64(len(cc)) * minDensity)
320 if constant.Compare(width, token.GTR, limit) {
321
322
323 return false
324 }
325 jt := ir.NewJumpTableStmt(base.Pos, s.exprname)
326 for _, c := range cc {
327 jmp := c.jmp.(*ir.BranchStmt)
328 if jmp.Op() != ir.OGOTO || jmp.Label == nil {
329 panic("bad switch case body")
330 }
331 for i := c.lo.Val(); constant.Compare(i, token.LEQ, c.hi.Val()); i = constant.BinaryOp(i, token.ADD, constant.MakeInt64(1)) {
332 jt.Cases = append(jt.Cases, i)
333 jt.Targets = append(jt.Targets, jmp.Label)
334 }
335 }
336 out.Append(jt)
337 return true
338 }
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377 func tryLookupTable(sw *ir.SwitchStmt, cond ir.Node) {
378 const minCases = 4
379
380 if base.Flag.N != 0 {
381 return
382 }
383 if !cond.Type().IsInteger() {
384 return
385 }
386 if cond.Type().Size() > int64(types.PtrSize) {
387 return
388 }
389
390
391
392
393 for _, ncase := range sw.Cases {
394 if fall, _ := endsInFallthrough(ncase.Body); fall {
395 return
396 }
397 }
398
399 fn := ir.CurFunc
400 if fn == nil || fn.Type().NumResults() != 1 {
401 return
402 }
403 resultType := fn.Type().Results()[0].Type
404 if !resultType.IsInteger() {
405
406
407 return
408 }
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429 constSet := make(map[int64]constant.Value)
430 constCaseSet := make(map[int]bool)
431 var defaultVal constant.Value
432 var hasConstDefault bool
433 excludeSet := make(map[int64]bool)
434 minVal, maxVal := int64(math.MaxInt64), int64(math.MinInt64)
435
436 for i, ncase := range sw.Cases {
437 if len(ncase.List) == 0 {
438
439 if isConstIntReturn(ncase) {
440 hasConstDefault = true
441 defaultVal = ncase.Body[0].(*ir.ReturnStmt).Results[0].Val()
442 }
443 continue
444 }
445
446 vals, ok := constIntCaseVals(ncase)
447 if !ok {
448
449
450
451
452
453
454
455
456
457 return
458 }
459
460 if !isConstIntReturn(ncase) {
461
462
463
464
465 for _, v := range vals {
466 excludeSet[v] = true
467 }
468 continue
469 }
470
471 retVal := ncase.Body[0].(*ir.ReturnStmt).Results[0].Val()
472 for _, v := range vals {
473 constSet[v] = retVal
474 minVal = min(minVal, v)
475 maxVal = max(maxVal, v)
476 }
477 constCaseSet[i] = true
478 }
479
480 if len(constSet) < minCases {
481 return
482 }
483
484 tableSize := maxVal - minVal + 1
485 if tableSize <= 0 || !isSwitchDense(int64(len(constSet)), tableSize) {
486 return
487 }
488
489
490
491 tabType := types.NewArray(resultType, tableSize)
492 tabName := readonlystaticname(tabType)
493 lsym := tabName.Linksym()
494 elemSize := int(resultType.Size())
495 maxBitmaskSize := int64(types.PtrSize * 8)
496
497 needMask := false
498 var bitmask uint64
499 validSlots := make([]bool, tableSize)
500 for i := range tableSize {
501 caseVal := minVal + i
502 var v int64
503 switch {
504 case excludeSet[caseVal]:
505
506 needMask = true
507 case constSet[caseVal] != nil:
508 v = ir.IntVal(resultType, constSet[caseVal])
509 validSlots[i] = true
510 bitmask |= 1 << uint(i)
511 case hasConstDefault:
512
513 v = ir.IntVal(resultType, defaultVal)
514 validSlots[i] = true
515 bitmask |= 1 << uint(i)
516 default:
517
518 needMask = true
519 }
520 lsym.WriteInt(base.Ctxt, i*int64(elemSize), elemSize, v)
521 }
522
523
524
525
526
527 var maskName *ir.Name
528 useBitmask := needMask && tableSize <= maxBitmaskSize
529 if needMask && !useBitmask {
530 maskType := types.NewArray(types.Types[types.TUINT8], tableSize)
531 maskName = readonlystaticname(maskType)
532 maskSym := maskName.Linksym()
533 for i := range tableSize {
534 var v uint8
535 if validSlots[i] {
536 v = 1
537 }
538 maskSym.WriteInt(base.Ctxt, i, 1, int64(v))
539 }
540 }
541
542
543
544
545 pos := sw.Pos()
546
547
548 intType := types.Types[types.TINT]
549 wideCond := typecheck.Conv(cond, intType)
550
551
552 var idx ir.Node
553 if minVal != 0 {
554 minLit := ir.NewBasicLit(pos, intType, constant.MakeInt64(minVal))
555 idx = typecheck.Expr(ir.NewBinaryExpr(pos, ir.OSUB, wideCond, minLit))
556 } else {
557 idx = wideCond
558 }
559
560
561
562 uintType := types.Types[types.TUINT]
563 uidx := typecheck.Conv(idx, uintType)
564 uidx = copyExpr(uidx, uintType, &sw.Compiled)
565
566
567 rangeLit := ir.NewBasicLit(pos, uintType, constant.MakeUint64(uint64(maxVal-minVal)))
568 boundsCheck := typecheck.Expr(ir.NewBinaryExpr(pos, ir.OLE, uidx, rangeLit))
569 boundsCheck = typecheck.DefaultLit(boundsCheck, nil)
570
571
572 lookup := ir.NewIndexExpr(pos, tabName, uidx)
573 lookup.SetBounded(true)
574 lookup = typecheck.Expr(lookup).(*ir.IndexExpr)
575
576 retStmt := ir.NewReturnStmt(pos, []ir.Node{lookup})
577
578 var ifBody []ir.Node
579 if needMask {
580 var maskCheck ir.Node
581 if useBitmask {
582
583
584 bitmaskType := types.Types[types.TUINTPTR]
585 bitmaskLit := ir.NewBasicLit(pos, bitmaskType, constant.MakeUint64(bitmask))
586 shifted := typecheck.Expr(ir.NewBinaryExpr(pos, ir.ORSH, bitmaskLit, uidx))
587 one := ir.NewBasicLit(pos, bitmaskType, constant.MakeUint64(1))
588 masked := typecheck.Expr(ir.NewBinaryExpr(pos, ir.OAND, shifted, one))
589 zero := ir.NewBasicLit(pos, bitmaskType, constant.MakeUint64(0))
590 maskCheck = typecheck.Expr(ir.NewBinaryExpr(pos, ir.ONE, masked, zero))
591 } else {
592
593 maskLookup := ir.NewIndexExpr(pos, maskName, uidx)
594 maskLookup.SetBounded(true)
595 maskLookup = typecheck.Expr(maskLookup).(*ir.IndexExpr)
596 zero := ir.NewBasicLit(pos, types.Types[types.TUINT8], constant.MakeInt64(0))
597 maskCheck = typecheck.Expr(ir.NewBinaryExpr(pos, ir.ONE, maskLookup, zero))
598 }
599 maskCheck = typecheck.DefaultLit(maskCheck, nil)
600
601 innerIf := ir.NewIfStmt(pos, maskCheck, []ir.Node{retStmt}, nil)
602 ifBody = []ir.Node{innerIf}
603 } else {
604 ifBody = []ir.Node{retStmt}
605 }
606
607 outerIf := ir.NewIfStmt(pos, boundsCheck, ifBody, nil)
608 sw.Compiled.Append(outerIf)
609
610
611
612 newCases := make([]*ir.CaseClause, 0, len(sw.Cases)-len(constCaseSet))
613 for i, ncase := range sw.Cases {
614 if !constCaseSet[i] {
615 newCases = append(newCases, ncase)
616 }
617 }
618 sw.Cases = newCases
619 }
620
621
622
623
624
// isSwitchDense reports whether numCases populated slots justify a table of
// tableSize slots, i.e. whether at least 40% of the slots are used. Table
// sizes large enough to overflow the percentage arithmetic are rejected.
func isSwitchDense(numCases, tableSize int64) bool {
	const minDensityPercent = 40
	const overflowLimit = math.MaxInt64 / 100
	if tableSize < overflowLimit {
		return 100*numCases >= minDensityPercent*tableSize
	}
	return false
}
632
633
634
635 func isConstIntReturn(ncase *ir.CaseClause) bool {
636 if len(ncase.Body) != 1 {
637 return false
638 }
639 ret, ok := ncase.Body[0].(*ir.ReturnStmt)
640 if !ok || len(ret.Results) != 1 {
641 return false
642 }
643 r := ret.Results[0]
644 return r.Op() == ir.OLITERAL && r.Val().Kind() == constant.Int
645 }
646
647
648
649
650 func constIntCaseVals(ncase *ir.CaseClause) (vals []int64, ok bool) {
651 for _, n1 := range ncase.List {
652 if n1.Op() != ir.OLITERAL || n1.Val().Kind() != constant.Int {
653 return nil, false
654 }
655 v, fit := constant.Int64Val(n1.Val())
656 if !fit {
657 return nil, false
658 }
659 vals = append(vals, v)
660 }
661 return vals, true
662 }
663
664 func (c *exprClause) test(exprname ir.Node) ir.Node {
665
666 if c.hi != c.lo {
667 low := ir.NewBinaryExpr(c.pos, ir.OGE, exprname, c.lo)
668 high := ir.NewBinaryExpr(c.pos, ir.OLE, exprname, c.hi)
669 return ir.NewLogicalExpr(c.pos, ir.OANDAND, low, high)
670 }
671
672
673 if ir.IsConst(exprname, constant.Bool) && !c.lo.Type().IsInterface() {
674 if ir.BoolVal(exprname) {
675 return c.lo
676 } else {
677 return ir.NewUnaryExpr(c.pos, ir.ONOT, c.lo)
678 }
679 }
680
681 n := ir.NewBinaryExpr(c.pos, ir.OEQ, exprname, c.lo)
682 n.RType = c.rtype
683 return n
684 }
685
686 func allCaseExprsAreSideEffectFree(sw *ir.SwitchStmt) bool {
687
688
689
690
691
692
693
694 for _, ncase := range sw.Cases {
695 for _, v := range ncase.List {
696 if v.Op() != ir.OLITERAL {
697 return false
698 }
699 }
700 }
701 return true
702 }
703
704
705 func endsInFallthrough(stmts []ir.Node) (bool, src.XPos) {
706 if len(stmts) == 0 {
707 return false, src.NoXPos
708 }
709 i := len(stmts) - 1
710 return stmts[i].Op() == ir.OFALL, stmts[i].Pos()
711 }
712
713
714
// walkSwitchType lowers a type switch statement into sw.Compiled.
//
// The interface value is evaluated once into a temporary, and a nil check
// jumps to the "case nil" target (or default). The remaining case types are
// processed in source order, and consecutive static cases are batched:
//   - Concrete types are matched by type hash (typeSwitch.flush) plus a
//     comma-ok type assertion.
//   - Non-empty interface types are matched together by one
//     InterfaceSwitchStmt; its integer result selects the case through a
//     nested switch.
//   - A trailing "any" case becomes an unconditional jump.
//   - A dynamic (runtime-known) type flushes the batch and is tested on its
//     own with a dynamic type assertion.
//
// Case types shadowed by an earlier interface case are dropped. Finally the
// labeled case bodies are appended. Each binds its case variable (if any)
// and ends with a break.
func walkSwitchType(sw *ir.SwitchStmt) {
	var s typeSwitch
	s.srcName = sw.Tag.(*ir.TypeSwitchGuard).X
	s.srcName = walkExpr(s.srcName, sw.PtrInit())
	s.srcName = copyExpr(s.srcName, s.srcName.Type(), &sw.Compiled)
	s.okName = typecheck.TempAt(base.Pos, ir.CurFunc, types.Types[types.TBOOL])
	s.itabName = typecheck.TempAt(base.Pos, ir.CurFunc, types.Types[types.TUINT8].PtrTo())

	// Split the interface value into its itab/type word and its data word.
	srcItab := ir.NewUnaryExpr(base.Pos, ir.OITAB, s.srcName)
	srcData := ir.NewUnaryExpr(base.Pos, ir.OIDATA, s.srcName)
	srcData.SetType(types.Types[types.TUINT8].PtrTo())
	srcData.SetTypecheck(1)

	// Nil check on the itab/type word. Its body (the nil or default target)
	// is filled in below, once the cases have been scanned.
	ifNil := ir.NewIfStmt(base.Pos, nil, nil, nil)
	ifNil.Cond = ir.NewBinaryExpr(base.Pos, ir.OEQ, srcItab, typecheck.NodNil())
	base.Pos = base.Pos.WithNotStmt()
	ifNil.Cond = typecheck.Expr(ifNil.Cond)
	ifNil.Cond = typecheck.DefaultLit(ifNil.Cond, nil)

	sw.Compiled.Append(ifNil)

	// Load the type hash once. It drives concrete-type dispatch, and dotHash
	// is also passed to the interface switch.
	dotHash := typeHashFieldOf(base.Pos, srcItab)
	s.hashName = copyExpr(dotHash, dotHash.Type(), &sw.Compiled)

	// One label per case body.
	labels := make([]*types.Sym, len(sw.Cases))
	for i := range sw.Cases {
		labels[i] = typecheck.AutoLabel(".s")
	}

	// Leaves the switch. It ends every case body and is the default target
	// when there is no default case.
	br := ir.NewBranchStmt(base.Pos, ir.OBREAK, nil)

	// oneCase describes one non-nil type listed in some case.
	type oneCase struct {
		pos src.XPos
		jmp ir.Node // jump to the body of the case

		// The listed type. Judging from its uses below, either an OTYPE
		// (static type) or an ODYNAMICTYPE (type known only at run time).
		typ ir.Node

		// For a case listing exactly one dynamic type and declaring a case
		// variable: a temporary that receives the asserted value, so the
		// case body can bind the variable without a second assertion.
		// Nil otherwise.
		val ir.Node
		idx int // index of the case in sw.Cases when val is set; -1 otherwise
	}
	var cases []oneCase
	var defaultGoto, nilGoto ir.Node
	for i, ncase := range sw.Cases {
		jmp := ir.NewBranchStmt(ncase.Pos(), ir.OGOTO, labels[i])
		if len(ncase.List) == 0 { // default:
			if defaultGoto != nil {
				base.Fatalf("duplicate default case not detected during typechecking")
			}
			defaultGoto = jmp
		}
		for _, n1 := range ncase.List {
			if ir.IsNil(n1) { // case nil:
				if nilGoto != nil {
					base.Fatalf("duplicate nil case not detected during typechecking")
				}
				nilGoto = jmp
				continue
			}
			idx := -1
			var val ir.Node
			// Single dynamic type with a case variable: allocate the
			// temporary that the dynamic assertion will fill.
			if len(ncase.List) == 1 && ncase.List[0].Op() == ir.ODYNAMICTYPE && ncase.Var != nil {
				val = typecheck.TempAt(ncase.Pos(), ir.CurFunc, ncase.Var.Type())
				idx = i
			}
			cases = append(cases, oneCase{
				pos: ncase.Pos(),
				typ: n1,
				jmp: jmp,
				val: val,
				idx: idx,
			})
		}
	}
	if defaultGoto == nil {
		defaultGoto = br
	}
	if nilGoto == nil {
		nilGoto = defaultGoto
	}
	ifNil.Body = []ir.Node{nilGoto}

	// Pending static cases, emitted in batches by flush.
	var concreteCases []oneCase
	var interfaceCases []oneCase
	flush := func() {
		// Emit all pending concrete cases before the interface cases. This
		// is safe because of the shadowing check in the loop below:
		//   - a later concrete type that implements an earlier pending
		//     interface type is dropped;
		//   - otherwise the concrete and interface cases cannot both match.
		if len(concreteCases) > 0 {
			var clauses []typeClause
			for _, c := range concreteCases {
				as := ir.NewAssignListStmt(c.pos, ir.OAS2,
					[]ir.Node{ir.BlankNode, s.okName},
					[]ir.Node{ir.NewTypeAssertExpr(c.pos, s.srcName, c.typ.Type())})
				nif := ir.NewIfStmt(c.pos, s.okName, []ir.Node{c.jmp}, nil)
				clauses = append(clauses, typeClause{
					hash: types.TypeHash(c.typ.Type()),
					body: []ir.Node{typecheck.Stmt(as), typecheck.Stmt(nif)},
				})
			}
			s.flush(clauses, &sw.Compiled)
			concreteCases = concreteCases[:0]
		}

		// Every type implements "any", so once an "any" case is pending the
		// shadowing check drops all later cases; it can only be last. Strip
		// it off so the runtime interface switch handles only non-empty
		// interfaces, and jump to it unconditionally afterwards.
		var anyGoto ir.Node
		if len(interfaceCases) > 0 && interfaceCases[len(interfaceCases)-1].typ.Type().IsEmptyInterface() {
			anyGoto = interfaceCases[len(interfaceCases)-1].jmp
			interfaceCases = interfaceCases[:len(interfaceCases)-1]
		}

		// Non-empty interface cases are resolved together at run time.
		if len(interfaceCases) > 0 {
			// Build a local static descriptor (rttype.InterfaceSwitch)
			// holding a cache pointer, the case count, and the type
			// descriptor of each interface case.
			lsym := types.LocalPkg.Lookup(fmt.Sprintf(".interfaceSwitch.%d", interfaceSwitchGen)).LinksymABI(obj.ABI0)
			interfaceSwitchGen++
			c := rttype.NewCursor(lsym, 0, rttype.InterfaceSwitch)
			c.Field("Cache").WritePtr(typecheck.LookupRuntimeVar("emptyInterfaceSwitchCache"))
			c.Field("NCases").WriteInt(int64(len(interfaceCases)))
			array, sizeDelta := c.Field("Cases").ModifyArray(len(interfaceCases))
			for i, c := range interfaceCases {
				array.Elem(i).WritePtr(reflectdata.TypeLinksym(c.typ.Type()))
			}
			objw.Global(lsym, int32(rttype.InterfaceSwitch.Size()+sizeDelta), obj.LOCAL)

			// Attach a Go type to the symbol, presumably so tools that
			// inspect symbols can decode it.
			lsym.Gotype = reflectdata.TypeLinksym(rttype.InterfaceSwitch)

			// The interface switch takes the dynamic type descriptor. For
			// an empty interface that is the first word itself; otherwise
			// it is loaded from the itab.
			var typeArg ir.Node
			if s.srcName.Type().IsEmptyInterface() {
				typeArg = ir.NewConvExpr(base.Pos, ir.OCONVNOP, types.Types[types.TUINT8].PtrTo(), srcItab)
			} else {
				typeArg = itabType(srcItab)
			}
			// caseVar receives the index of the matching interface case and
			// s.itabName the corresponding itab (used when binding case
			// variables below).
			caseVar := typecheck.TempAt(base.Pos, ir.CurFunc, types.Types[types.TINT])
			isw := ir.NewInterfaceSwitchStmt(base.Pos, caseVar, s.itabName, typeArg, dotHash, lsym)
			sw.Compiled.Append(isw)

			// Dispatch on caseVar: index i jumps to interface case i. Any
			// other value falls through.
			var newCases []*ir.CaseClause
			for i, c := range interfaceCases {
				newCases = append(newCases, &ir.CaseClause{
					List: []ir.Node{ir.NewInt(base.Pos, int64(i))},
					Body: []ir.Node{c.jmp},
				})
			}
			// Lowered later, when walkStmtList walks sw.Compiled.
			sw2 := ir.NewSwitchStmt(base.Pos, caseVar, newCases)
			sw.Compiled.Append(typecheck.Stmt(sw2))
			interfaceCases = interfaceCases[:0]
		}

		if anyGoto != nil {
			// Everything reaching this point (non-nil) matches "any".
			sw.Compiled.Append(anyGoto)
		}
	}
caseLoop:
	for _, c := range cases {
		if c.typ.Op() == ir.ODYNAMICTYPE {
			// Dynamic types cannot be batched. Flush the pending static
			// cases first to keep source order.
			flush()
			dt := c.typ.(*ir.DynamicType)
			dot := ir.NewDynamicTypeAssertExpr(c.pos, ir.ODYNAMICDOTTYPE, s.srcName, dt.RType)
			dot.ITab = dt.ITab
			dot.SetType(c.typ.Type())
			dot.SetTypecheck(1)

			// val, ok = src.(dynamic type); the value is kept only when a
			// case-variable temporary was allocated.
			as := ir.NewAssignListStmt(c.pos, ir.OAS2, nil, nil)
			as.Lhs = []ir.Node{ir.BlankNode, s.okName}
			if c.val != nil {
				as.Lhs[0] = c.val
			}
			as.Rhs = []ir.Node{dot}
			typecheck.Stmt(as)

			nif := ir.NewIfStmt(c.pos, s.okName, []ir.Node{c.jmp}, nil)
			sw.Compiled.Append(as, nif)
			continue
		}

		// Drop this case if a pending interface case already matches every
		// value it could match. For example, after "case io.Reader", a
		// "case *os.File" is unreachable when *os.File implements
		// io.Reader.
		for _, ic := range interfaceCases {
			if typecheck.Implements(c.typ.Type(), ic.typ.Type()) {
				continue caseLoop
			}
		}

		if c.typ.Type().IsInterface() {
			interfaceCases = append(interfaceCases, c)
		} else {
			concreteCases = append(concreteCases, c)
		}
	}
	flush()

	// Nothing matched.
	sw.Compiled.Append(defaultGoto)

	// Emit the case bodies.
	for i, ncase := range sw.Cases {
		sw.Compiled.Append(ir.NewLabelStmt(ncase.Pos(), labels[i]))
		if caseVar := ncase.Var; caseVar != nil {
			// With multiple types (or default), the variable has the
			// switch's own type and simply takes the source value.
			val := s.srcName
			if len(ncase.List) == 1 {
				// A single type: derive a value of exactly that type.
				if ncase.List[0].Op() == ir.OTYPE { // static type
					t := ncase.List[0].Type()
					if t.IsInterface() {
						// Interface target: reuse the data word and supply
						// the appropriate type or itab word.
						if t.IsEmptyInterface() {
							var typ ir.Node
							if s.srcName.Type().IsEmptyInterface() {
								// E -> E: the first word already is the type.
								typ = srcItab
							} else {
								// I -> E: load the type out of the itab.
								typ = itabType(srcItab)
								typ.SetPos(ncase.Pos())
							}
							val = ir.NewBinaryExpr(ncase.Pos(), ir.OMAKEFACE, typ, srcData)
						} else {
							// Non-empty interface: the itab was produced by
							// the interface switch into s.itabName.
							val = ir.NewBinaryExpr(ncase.Pos(), ir.OMAKEFACE, s.itabName, srcData)
						}
					} else {
						// Concrete target: read the value out of the interface.
						val = ifaceData(ncase.Pos(), s.srcName, t)
					}
				} else if ncase.List[0].Op() == ir.ODYNAMICTYPE { // runtime type
					// Reuse the temporary filled by the dynamic assertion.
					var found bool
					for _, c := range cases {
						if c.idx == i {
							val = c.val
							found = val != nil
							break
						}
					}
					// A temporary was allocated for every such case above.
					if !found {
						base.Fatalf("an error occurred when processing type switch case %v", ncase.List[0])
					}
				} else if ir.IsNil(ncase.List[0]) {
					// "case nil": the variable takes the (nil) source value.
				} else {
					base.Fatalf("unhandled type switch case %v", ncase.List[0])
				}
				val.SetType(caseVar.Type())
				val.SetTypecheck(1)
			}
			l := []ir.Node{
				ir.NewDecl(ncase.Pos(), ir.ODCL, caseVar),
				ir.NewAssignStmt(ncase.Pos(), caseVar, val),
			}
			typecheck.Stmts(l)
			sw.Compiled.Append(l...)
		}
		sw.Compiled.Append(ncase.Body...)
		sw.Compiled.Append(br)
	}

	walkStmtList(sw.Compiled)
	sw.Tag = nil
	sw.Cases = nil
}

// interfaceSwitchGen numbers the ".interfaceSwitch.N" descriptor symbols
// created by walkSwitchType so that each gets a distinct name.
var interfaceSwitchGen int
1025
1026
1027
1028
1029 func typeHashFieldOf(pos src.XPos, itab *ir.UnaryExpr) *ir.SelectorExpr {
1030 if itab.Op() != ir.OITAB {
1031 base.Fatalf("expected OITAB, got %v", itab.Op())
1032 }
1033 var hashField *types.Field
1034 if itab.X.Type().IsEmptyInterface() {
1035
1036 if rtypeHashField == nil {
1037 rtypeHashField = runtimeField("hash", rttype.Type.OffsetOf("Hash"), types.Types[types.TUINT32])
1038 }
1039 hashField = rtypeHashField
1040 } else {
1041
1042 if itabHashField == nil {
1043 itabHashField = runtimeField("hash", rttype.ITab.OffsetOf("Hash"), types.Types[types.TUINT32])
1044 }
1045 hashField = itabHashField
1046 }
1047 return boundedDotPtr(pos, itab, hashField)
1048 }
1049
1050 var rtypeHashField, itabHashField *types.Field
1051
1052
1053 type typeSwitch struct {
1054
1055 srcName ir.Node
1056 hashName ir.Node
1057 okName ir.Node
1058 itabName ir.Node
1059 }
1060
1061 type typeClause struct {
1062 hash uint32
1063 body ir.Nodes
1064 }
1065
1066 func (s *typeSwitch) flush(cc []typeClause, compiled *ir.Nodes) {
1067 if len(cc) == 0 {
1068 return
1069 }
1070
1071 slices.SortFunc(cc, func(a, b typeClause) int { return cmp.Compare(a.hash, b.hash) })
1072
1073
1074 merged := cc[:1]
1075 for _, c := range cc[1:] {
1076 last := &merged[len(merged)-1]
1077 if last.hash == c.hash {
1078 last.body.Append(c.body.Take()...)
1079 } else {
1080 merged = append(merged, c)
1081 }
1082 }
1083 cc = merged
1084
1085 if s.tryJumpTable(cc, compiled) {
1086 return
1087 }
1088 binarySearch(len(cc), compiled,
1089 func(i int) ir.Node {
1090 return ir.NewBinaryExpr(base.Pos, ir.OLE, s.hashName, ir.NewInt(base.Pos, int64(cc[i-1].hash)))
1091 },
1092 func(i int, nif *ir.IfStmt) {
1093
1094
1095 c := cc[i]
1096 nif.Cond = ir.NewBinaryExpr(base.Pos, ir.OEQ, s.hashName, ir.NewInt(base.Pos, int64(c.hash)))
1097 nif.Body.Append(c.body.Take()...)
1098 },
1099 )
1100 }
1101
1102
// tryJumpTable attempts to dispatch the hash-sorted, deduplicated clauses cc
// through a jump table indexed by a b-bit slice of s.hashName, and reports
// whether it did so.
//
// It looks for a width b and a shift i such that (hash>>i) & (1<<b-1) is
// distinct for every clause. The search starts with the smallest b giving
// each clause its own slot, tries up to two more bits, and tries every shift
// for each width. The emitted code is:
//
//	jumptable((hash>>i) & (1<<b-1)) -> label_k, or noMatch for unused slots
//	goto noMatch
//	label_k: <clause k body>; goto noMatch
//	...
//	noMatch:
//
// Different types may share these bits, so each clause body must re-check
// the type itself. The bodies built by walkSwitchType perform a comma-ok
// type assertion.
func (s *typeSwitch) tryJumpTable(cc []typeClause, out *ir.Nodes) bool {
	const minCases = 5 // fewest clauses worth a table
	if base.Flag.N != 0 || !ssagen.Arch.LinkArch.CanJumpTable || base.Ctxt.Retpoline {
		return false
	}
	if len(cc) < minCases {
		return false
	}
	hashes := make([]uint32, len(cc))

	// b0 is the smallest width with 1<<b0 >= len(cc).
	b0 := bits.Len(uint(len(cc) - 1))
	for b := b0; b < b0+3; b++ {
	pickI:
		for i := 0; i <= 32-b; i++ { // i == 0 first: it needs no shift
			// Collect the candidate table index of every clause and reject
			// this (b, i) on any collision.
			hashes = hashes[:0]
			for _, c := range cc {
				h := c.hash >> i & (1<<b - 1)
				hashes = append(hashes, h)
			}

			slices.Sort(hashes)
			for j := 1; j < len(hashes); j++ {
				if hashes[j] == hashes[j-1] {
					// Two clauses map to the same slot.
					continue pickI
				}
			}

			// Build the index expression (hash >> i) & (1<<b - 1).
			h := s.hashName
			if i != 0 {
				h = ir.NewBinaryExpr(base.Pos, ir.ORSH, h, ir.NewInt(base.Pos, int64(i)))
			}
			h = ir.NewBinaryExpr(base.Pos, ir.OAND, h, ir.NewInt(base.Pos, int64(1<<b-1)))
			h = typecheck.Expr(h)

			// One table entry per possible index value.
			jt := ir.NewJumpTableStmt(base.Pos, h)
			jt.Cases = make([]constant.Value, 1<<b)
			jt.Targets = make([]*types.Sym, 1<<b)
			out.Append(jt)

			// Every slot defaults to noMatch; used slots are retargeted below.
			noMatch := typecheck.AutoLabel(".s")
			for j := 0; j < 1<<b; j++ {
				jt.Cases[j] = constant.MakeInt64(int64(j))
				jt.Targets[j] = noMatch
			}

			// Code following the jump table.
			out.Append(ir.NewBranchStmt(base.Pos, ir.OGOTO, noMatch))

			// Emit each clause body behind its own label and point its slot
			// at that label.
			for _, c := range cc {
				h := c.hash >> i & (1<<b - 1)
				label := typecheck.AutoLabel(".s")
				jt.Targets[h] = label
				out.Append(ir.NewLabelStmt(base.Pos, label))
				out.Append(c.body...)
				// Reached only if the body's type check failed.
				out.Append(ir.NewBranchStmt(base.Pos, ir.OGOTO, noMatch))
			}

			out.Append(ir.NewLabelStmt(base.Pos, noMatch))
			return true
		}
	}
	// No collision-free (b, i) found.
	return false
}
1176
1177
1178
1179
1180
1181
1182
1183
1184
1185
// binarySearch appends to out a search over n cases ordered by less.
//
// less(i) must return a condition under which only cases before i (that
// is, in [0, i)) can match. leaf(i, nif) must fill in nif.Cond and the body
// that handles case i.
//
// Ranges of fewer than binarySearchMin cases become an if / else-if chain
// of leaves. Longer ranges are split at their midpoint with
// "if less(half) { lower } else { upper }".
func binarySearch(n int, out *ir.Nodes, less func(i int) ir.Node, leaf func(i int, nif *ir.IfStmt)) {
	const binarySearchMin = 4 // minimum number of cases for binary search

	var do func(lo, hi int, out *ir.Nodes)
	do = func(lo, hi int, out *ir.Nodes) {
		n := hi - lo
		if n < binarySearchMin {
			// Linear chain: each leaf is placed in the Else of the previous one.
			for i := lo; i < hi; i++ {
				nif := ir.NewIfStmt(base.Pos, nil, nil, nil)
				leaf(i, nif)
				// Mark positions of the code generated from here on as
				// non-statements.
				base.Pos = base.Pos.WithNotStmt()
				nif.Cond = typecheck.Expr(nif.Cond)
				nif.Cond = typecheck.DefaultLit(nif.Cond, nil)
				out.Append(nif)
				out = &nif.Else
			}
			return
		}

		half := lo + n/2
		nif := ir.NewIfStmt(base.Pos, nil, nil, nil)
		nif.Cond = less(half)
		base.Pos = base.Pos.WithNotStmt()
		nif.Cond = typecheck.Expr(nif.Cond)
		nif.Cond = typecheck.DefaultLit(nif.Cond, nil)
		do(lo, half, &nif.Body)
		do(half, hi, &nif.Else)
		out.Append(nif)
	}

	do(0, n, out)
}
1218
// stringSearch appends to out code that jumps to the clause of cc whose
// string equals expr. Every string in cc must have the same length: the
// loop below indexes each one at every idx < len(cc[0]). flush calls this
// only for an equal-length run after dispatching on len(expr).
//
// Fewer than four clauses are tested with a chain of equality checks.
// Otherwise the function:
//   - chooses the byte index and split value that divide cc most evenly
//     (maximizing |le| * |gt|);
//   - emits "if int8(expr[idx]) <= b";
//   - recurses on each half.
func stringSearch(expr ir.Node, cc []exprClause, out *ir.Nodes) {
	if len(cc) < 4 {
		// Few clauses: brute-force equality checks.
		for _, c := range cc {
			nif := ir.NewIfStmt(base.Pos.WithNotStmt(), typecheck.DefaultLit(typecheck.Expr(c.test(expr)), nil), []ir.Node{c.jmp}, nil)
			out.Append(nif)
			out = &nif.Else
		}
		return
	}

	// Exhaustively score every (byte index, split value) pair. Bytes are
	// compared as int8. b == 127 is not tried: it would put every byte on
	// the <= side and score 0.
	n := len(ir.StringVal(cc[0].lo))
	bestScore := int64(0)
	bestIdx := 0
	bestByte := int8(0)
	for idx := 0; idx < n; idx++ {
		for b := int8(-128); b < 127; b++ {
			le := 0
			for _, c := range cc {
				s := ir.StringVal(c.lo)
				if int8(s[idx]) <= b {
					le++
				}
			}
			score := int64(le) * int64(len(cc)-le)
			if score > bestScore {
				bestScore = score
				bestIdx = idx
				bestByte = b
			}
		}
	}

	// Score 0 means every split leaves one side empty (e.g. identical
	// strings), so recursion could not make progress.
	if bestScore == 0 {
		base.Fatalf("unable to split string set")
	}

	// View expr as a []int8 via OSTR2BYTESTMP.
	slice := ir.NewConvExpr(base.Pos, ir.OSTR2BYTESTMP, types.NewSlice(types.Types[types.TINT8]), expr)
	slice.SetTypecheck(1)
	slice.MarkNonNil()
	// Load the byte being split on.
	load := ir.NewIndexExpr(base.Pos, slice, ir.NewInt(base.Pos, int64(bestIdx)))
	// Compare it with the split value. (This local shadows package cmp.)
	cmp := ir.Node(ir.NewBinaryExpr(base.Pos, ir.OLE, load, ir.NewInt(base.Pos, int64(bestByte))))
	cmp = typecheck.DefaultLit(typecheck.Expr(cmp), nil)
	nif := ir.NewIfStmt(base.Pos, cmp, nil, nil)

	// Partition the clauses the same way the generated test does.
	var le []exprClause
	var gt []exprClause
	for _, c := range cc {
		s := ir.StringVal(c.lo)
		if int8(s[bestIdx]) <= bestByte {
			le = append(le, c)
		} else {
			gt = append(gt, c)
		}
	}
	stringSearch(expr, le, &nif.Body)
	stringSearch(expr, gt, &nif.Else)
	out.Append(nif)
}
1305
View as plain text