Source file src/simd/archsimd/_gen/simdgen/gen_simdrules.go

     1  // Copyright 2025 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package main
     6  
     7  import (
     8  	"bytes"
     9  	"fmt"
    10  	"slices"
    11  	"strings"
    12  	"text/template"
    13  )
    14  
// tplRuleData carries the values substituted into one of the named templates
// in ruleTemplates to print a single SSA rule. tplName selects which named
// template is executed (writeSIMDRules passes it to ExecuteTemplate). The
// "e.g." values below are illustrative only.
type tplRuleData struct {
	tplName        string // template name, e.g. "sftimm"
	GoOp           string // e.g. "ShiftAllLeft"
	GoType         string // e.g. "Uint32x8"
	Args           string // match-side args, e.g. "x y" or "[a] x y"
	Asm            string // machine op, e.g. "VPSLLD256"
	ArgsOut        string // result-side args, e.g. "x y"; may carry an aux such as "[a+...]"
	MaskInConvert  string // e.g. "VPMOVVec32x8ToM"
	MaskOutConvert string // e.g. "VPMOVMToVec32x8"
	ElementSize    int    // e.g. 32; set for maskIn rules and read by MaskOptimization
	Size           int    // e.g. 128; NOTE(review): not assigned by writeSIMDRules
	ArgsLoadAddr   string // [Args] with its last vreg arg replaced by a concrete "l:(VMOVDQUload<bits> {sym} [off] ptr mem)", possibly followed by "mask".
	ArgsAddr       string // [Args] with its last vreg arg replaced by "ptr", prefixed with "{sym}" and an aux offset, possibly followed by "mask", and ending in "mem".
	FeatCheck      string // e.g. "v.Block.CPUfeatures.hasFeature(CPUavx512)" -- a rule condition for a ssa/_gen rules file.
}
    30  
    31  var (
    32  	ruleTemplates = template.Must(template.New("simdRules").Parse(`
    33  {{define "pureVreg"}}({{.GoOp}}{{.GoType}} {{.Args}}) => ({{.Asm}} {{.ArgsOut}})
    34  {{end}}
    35  {{define "maskIn"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => ({{.Asm}} {{.ArgsOut}} ({{.MaskInConvert}} <types.TypeMask> mask))
    36  {{end}}
    37  {{define "maskOut"}}({{.GoOp}}{{.GoType}} {{.Args}}) => ({{.MaskOutConvert}} ({{.Asm}} {{.ArgsOut}}))
    38  {{end}}
    39  {{define "maskInMaskOut"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => ({{.MaskOutConvert}} ({{.Asm}} {{.ArgsOut}} ({{.MaskInConvert}} <types.TypeMask> mask)))
    40  {{end}}
    41  {{define "sftimm"}}({{.Asm}} x (MOVQconst [c])) => ({{.Asm}}const [uint8(c)] x)
    42  {{end}}
    43  {{define "masksftimm"}}({{.Asm}} x (MOVQconst [c]) mask) => ({{.Asm}}const [uint8(c)] x mask)
    44  {{end}}
    45  {{define "vregMem"}}({{.Asm}} {{.ArgsLoadAddr}}) && canMergeLoad(v, l) && clobber(l) => ({{.Asm}}load {{.ArgsAddr}})
    46  {{end}}
    47  {{define "vregMemFeatCheck"}}({{.Asm}} {{.ArgsLoadAddr}}) && {{.FeatCheck}} && canMergeLoad(v, l) && clobber(l)=> ({{.Asm}}load {{.ArgsAddr}})
    48  {{end}}
    49  `))
    50  )
    51  
    52  func (d tplRuleData) MaskOptimization(asmCheck map[string]bool) string {
    53  	asmNoMask := d.Asm
    54  	if i := strings.Index(asmNoMask, "Masked"); i == -1 {
    55  		return ""
    56  	}
    57  	asmNoMask = strings.ReplaceAll(asmNoMask, "Masked", "")
    58  	if asmCheck[asmNoMask] == false {
    59  		return ""
    60  	}
    61  
    62  	for _, nope := range []string{"VMOVDQU", "VPCOMPRESS", "VCOMPRESS", "VPEXPAND", "VEXPAND", "VPBLENDM", "VMOVUP"} {
    63  		if strings.HasPrefix(asmNoMask, nope) {
    64  			return ""
    65  		}
    66  	}
    67  
    68  	size := asmNoMask[len(asmNoMask)-3:]
    69  	if strings.HasSuffix(asmNoMask, "const") {
    70  		sufLen := len("128const")
    71  		size = asmNoMask[len(asmNoMask)-sufLen:][:3]
    72  	}
    73  	switch size {
    74  	case "128", "256", "512":
    75  	default:
    76  		panic("Unexpected operation size on " + d.Asm)
    77  	}
    78  
    79  	switch d.ElementSize {
    80  	case 8, 16, 32, 64:
    81  	default:
    82  		panic(fmt.Errorf("Unexpected operation width %d on %v", d.ElementSize, d.Asm))
    83  	}
    84  
    85  	return fmt.Sprintf("(VMOVDQU%dMasked%s (%s %s) mask) => (%s %s mask)\n", d.ElementSize, size, asmNoMask, d.Args, d.Asm, d.Args)
    86  }
    87  
    88  // SSA rewrite rules need to appear in a most-to-least-specific order.  This works for that.
    89  var tmplOrder = map[string]int{
    90  	"masksftimm":    0,
    91  	"sftimm":        1,
    92  	"maskInMaskOut": 2,
    93  	"maskOut":       3,
    94  	"maskIn":        4,
    95  	"pureVreg":      5,
    96  	"vregMem":       6,
    97  }
    98  
    99  func compareTplRuleData(x, y tplRuleData) int {
   100  	if c := compareNatural(x.GoOp, y.GoOp); c != 0 {
   101  		return c
   102  	}
   103  	if c := compareNatural(x.GoType, y.GoType); c != 0 {
   104  		return c
   105  	}
   106  	if c := compareNatural(x.Args, y.Args); c != 0 {
   107  		return c
   108  	}
   109  	if x.tplName == y.tplName {
   110  		return 0
   111  	}
   112  	xo, xok := tmplOrder[x.tplName]
   113  	yo, yok := tmplOrder[y.tplName]
   114  	if !xok {
   115  		panic(fmt.Errorf("Unexpected template name %s, please add to tmplOrder", x.tplName))
   116  	}
   117  	if !yok {
   118  		panic(fmt.Errorf("Unexpected template name %s, please add to tmplOrder", y.tplName))
   119  	}
   120  	return xo - yo
   121  }
   122  
   123  // writeSIMDRules generates the lowering and rewrite rules for ssa and writes it to simdAMD64.rules
   124  // within the specified directory.
   125  func writeSIMDRules(ops []Operation) *bytes.Buffer {
   126  	buffer := new(bytes.Buffer)
   127  	buffer.WriteString(generatedHeader + "\n")
   128  
   129  	// asm -> masked merging rules
   130  	maskedMergeOpts := make(map[string]string)
   131  	s2n := map[int]string{8: "B", 16: "W", 32: "D", 64: "Q"}
   132  	asmCheck := map[string]bool{}    // for masked merge optimizations.
   133  	sftimmCheck := map[string]bool{} // deduplicate sftimm rules
   134  	var allData []tplRuleData
   135  	var optData []tplRuleData    // for mask peephole optimizations, and other misc
   136  	var memOptData []tplRuleData // for memory peephole optimizations
   137  	memOpSeen := make(map[string]bool)
   138  
   139  	for _, opr := range ops {
   140  		opInShape, opOutShape, maskType, immType, gOp := opr.shape()
   141  		asm := machineOpName(maskType, gOp)
   142  		vregInCnt := len(gOp.In)
   143  		if maskType == OneMask {
   144  			vregInCnt--
   145  		}
   146  
   147  		data := tplRuleData{
   148  			GoOp: gOp.Go,
   149  			Asm:  asm,
   150  		}
   151  
   152  		if vregInCnt == 1 {
   153  			data.Args = "x"
   154  			data.ArgsOut = data.Args
   155  		} else if vregInCnt == 2 {
   156  			data.Args = "x y"
   157  			data.ArgsOut = data.Args
   158  		} else if vregInCnt == 3 {
   159  			data.Args = "x y z"
   160  			data.ArgsOut = data.Args
   161  		} else {
   162  			panic(fmt.Errorf("simdgen does not support more than 3 vreg in inputs"))
   163  		}
   164  		if immType == ConstImm {
   165  			data.ArgsOut = fmt.Sprintf("[%s] %s", *opr.In[0].Const, data.ArgsOut)
   166  		} else if immType == VarImm {
   167  			data.Args = fmt.Sprintf("[a] %s", data.Args)
   168  			data.ArgsOut = fmt.Sprintf("[a] %s", data.ArgsOut)
   169  		} else if immType == ConstVarImm {
   170  			data.Args = fmt.Sprintf("[a] %s", data.Args)
   171  			data.ArgsOut = fmt.Sprintf("[a+%s] %s", *opr.In[0].Const, data.ArgsOut)
   172  		}
   173  
   174  		goType := func(op Operation) string {
   175  			if op.OperandOrder != nil {
   176  				switch *op.OperandOrder {
   177  				case "21Type1", "231Type1":
   178  					// Permute uses operand[1] for method receiver.
   179  					return *op.In[1].Go
   180  				}
   181  			}
   182  			return *op.In[0].Go
   183  		}
   184  		var tplName string
   185  		// If class overwrite is happening, that's not really a mask but a vreg.
   186  		if opOutShape == OneVregOut || opOutShape == OneVregOutAtIn || gOp.Out[0].OverwriteClass != nil {
   187  			switch opInShape {
   188  			case OneImmIn:
   189  				tplName = "pureVreg"
   190  				data.GoType = goType(gOp)
   191  			case PureVregIn:
   192  				tplName = "pureVreg"
   193  				data.GoType = goType(gOp)
   194  			case OneKmaskImmIn:
   195  				fallthrough
   196  			case OneKmaskIn:
   197  				tplName = "maskIn"
   198  				data.GoType = goType(gOp)
   199  				rearIdx := len(gOp.In) - 1
   200  				// Mask is at the end.
   201  				width := *gOp.In[rearIdx].ElemBits
   202  				data.MaskInConvert = fmt.Sprintf("VPMOVVec%dx%dToM", width, *gOp.In[rearIdx].Lanes)
   203  				data.ElementSize = width
   204  			case PureKmaskIn:
   205  				panic(fmt.Errorf("simdgen does not support pure k mask instructions, they should be generated by compiler optimizations"))
   206  			}
   207  		} else if opOutShape == OneGregOut {
   208  			tplName = "pureVreg" // TODO this will be wrong
   209  			data.GoType = goType(gOp)
   210  		} else {
   211  			// OneKmaskOut case
   212  			data.MaskOutConvert = fmt.Sprintf("VPMOVMToVec%dx%d", *gOp.Out[0].ElemBits, *gOp.In[0].Lanes)
   213  			switch opInShape {
   214  			case OneImmIn:
   215  				fallthrough
   216  			case PureVregIn:
   217  				tplName = "maskOut"
   218  				data.GoType = goType(gOp)
   219  			case OneKmaskImmIn:
   220  				fallthrough
   221  			case OneKmaskIn:
   222  				tplName = "maskInMaskOut"
   223  				data.GoType = goType(gOp)
   224  				rearIdx := len(gOp.In) - 1
   225  				data.MaskInConvert = fmt.Sprintf("VPMOVVec%dx%dToM", *gOp.In[rearIdx].ElemBits, *gOp.In[rearIdx].Lanes)
   226  			case PureKmaskIn:
   227  				panic(fmt.Errorf("simdgen does not support pure k mask instructions, they should be generated by compiler optimizations"))
   228  			}
   229  		}
   230  
   231  		if gOp.SpecialLower != nil {
   232  			if *gOp.SpecialLower == "sftimm" {
   233  				if !sftimmCheck[data.Asm] {
   234  					sftimmCheck[data.Asm] = true
   235  					sftImmData := data
   236  					if tplName == "maskIn" {
   237  						sftImmData.tplName = "masksftimm"
   238  					} else {
   239  						sftImmData.tplName = "sftimm"
   240  					}
   241  					allData = append(allData, sftImmData)
   242  					asmCheck[sftImmData.Asm+"const"] = true
   243  				}
   244  			} else {
   245  				panic("simdgen sees unknwon special lower " + *gOp.SpecialLower + ", maybe implement it?")
   246  			}
   247  		}
   248  		if gOp.MemFeatures != nil && *gOp.MemFeatures == "vbcst" {
   249  			// sanity check
   250  			selected := true
   251  			for _, a := range gOp.In {
   252  				if a.TreatLikeAScalarOfSize != nil {
   253  					selected = false
   254  					break
   255  				}
   256  			}
   257  			if _, ok := memOpSeen[data.Asm]; ok {
   258  				selected = false
   259  			}
   260  			if selected {
   261  				memOpSeen[data.Asm] = true
   262  				lastVreg := gOp.In[vregInCnt-1]
   263  				// sanity check
   264  				if lastVreg.Class != "vreg" {
   265  					panic(fmt.Errorf("simdgen expects vbcst replaced operand to be a vreg, but %v found", lastVreg))
   266  				}
   267  				memOpData := data
   268  				// Remove the last vreg from the arg and change it to a load.
   269  				origArgs := data.Args[:len(data.Args)-1]
   270  				// Prepare imm args.
   271  				immArg := ""
   272  				immArgCombineOff := " [off] "
   273  				if immType != NoImm && immType != InvalidImm {
   274  					_, after, found := strings.Cut(origArgs, "]")
   275  					if found {
   276  						origArgs = after
   277  					}
   278  					immArg = "[c] "
   279  					immArgCombineOff = " [makeValAndOff(int32(uint8(c)),off)] "
   280  				}
   281  				memOpData.ArgsLoadAddr = immArg + origArgs + fmt.Sprintf("l:(VMOVDQUload%d {sym} [off] ptr mem)", *lastVreg.Bits)
   282  				// Remove the last vreg from the arg and change it to "ptr".
   283  				memOpData.ArgsAddr = "{sym}" + immArgCombineOff + origArgs + "ptr"
   284  				if maskType == OneMask {
   285  					memOpData.ArgsAddr += " mask"
   286  					memOpData.ArgsLoadAddr += " mask"
   287  				}
   288  				memOpData.ArgsAddr += " mem"
   289  				if gOp.MemFeaturesData != nil {
   290  					_, feat2 := getVbcstData(*gOp.MemFeaturesData)
   291  					knownFeatChecks := map[string]string{
   292  						"AVX":    "v.Block.CPUfeatures.hasFeature(CPUavx)",
   293  						"AVX2":   "v.Block.CPUfeatures.hasFeature(CPUavx2)",
   294  						"AVX512": "v.Block.CPUfeatures.hasFeature(CPUavx512)",
   295  					}
   296  					memOpData.FeatCheck = knownFeatChecks[feat2]
   297  					memOpData.tplName = "vregMemFeatCheck"
   298  				} else {
   299  					memOpData.tplName = "vregMem"
   300  				}
   301  				memOptData = append(memOptData, memOpData)
   302  				asmCheck[memOpData.Asm+"load"] = true
   303  			}
   304  		}
   305  		// Generate the masked merging optimization rules
   306  		if gOp.hasMaskedMerging(maskType, opOutShape) {
   307  			// TODO: handle customized operand order and special lower.
   308  			maskElem := gOp.In[len(gOp.In)-1]
   309  			if maskElem.Bits == nil {
   310  				panic("mask has no bits")
   311  			}
   312  			if maskElem.ElemBits == nil {
   313  				panic("mask has no elemBits")
   314  			}
   315  			if maskElem.Lanes == nil {
   316  				panic("mask has no lanes")
   317  			}
   318  			switch *maskElem.Bits {
   319  			case 128, 256:
   320  				// VPBLENDVB cases.
   321  				noMaskName := machineOpName(NoMask, gOp)
   322  				ruleExisting, ok := maskedMergeOpts[noMaskName]
   323  				rule := fmt.Sprintf("(VPBLENDVB%d dst (%s %s) mask) && v.Block.CPUfeatures.hasFeature(CPUavx512) => (%sMerging dst %s (VPMOVVec%dx%dToM <types.TypeMask> mask))\n",
   324  					*maskElem.Bits, noMaskName, data.Args, data.Asm, data.Args, *maskElem.ElemBits, *maskElem.Lanes)
   325  				if ok && ruleExisting != rule {
   326  					panic(fmt.Sprintf("multiple masked merge rules for one op:\n%s\n%s\n", ruleExisting, rule))
   327  				} else {
   328  					maskedMergeOpts[noMaskName] = rule
   329  				}
   330  			case 512:
   331  				// VPBLENDM[BWDQ] cases.
   332  				noMaskName := machineOpName(NoMask, gOp)
   333  				ruleExisting, ok := maskedMergeOpts[noMaskName]
   334  				rule := fmt.Sprintf("(VPBLENDM%sMasked%d dst (%s %s) mask) => (%sMerging dst %s mask)\n",
   335  					s2n[*maskElem.ElemBits], *maskElem.Bits, noMaskName, data.Args, data.Asm, data.Args)
   336  				if ok && ruleExisting != rule {
   337  					panic(fmt.Sprintf("multiple masked merge rules for one op:\n%s\n%s\n", ruleExisting, rule))
   338  				} else {
   339  					maskedMergeOpts[noMaskName] = rule
   340  				}
   341  			}
   342  		}
   343  
   344  		if tplName == "pureVreg" && data.Args == data.ArgsOut {
   345  			data.Args = "..."
   346  			data.ArgsOut = "..."
   347  		}
   348  		data.tplName = tplName
   349  		if opr.NoGenericOps != nil && *opr.NoGenericOps == "true" ||
   350  			opr.SkipMaskedMethod() {
   351  			optData = append(optData, data)
   352  			continue
   353  		}
   354  		allData = append(allData, data)
   355  		asmCheck[data.Asm] = true
   356  	}
   357  
   358  	slices.SortFunc(allData, compareTplRuleData)
   359  
   360  	for _, data := range allData {
   361  		if err := ruleTemplates.ExecuteTemplate(buffer, data.tplName, data); err != nil {
   362  			panic(fmt.Errorf("failed to execute template %s for %s: %w", data.tplName, data.GoOp+data.GoType, err))
   363  		}
   364  	}
   365  
   366  	seen := make(map[string]bool)
   367  
   368  	for _, data := range optData {
   369  		if data.tplName == "maskIn" {
   370  			rule := data.MaskOptimization(asmCheck)
   371  			if seen[rule] {
   372  				continue
   373  			}
   374  			seen[rule] = true
   375  			buffer.WriteString(rule)
   376  		}
   377  	}
   378  
   379  	maskedMergeOptsRules := []string{}
   380  	for asm, rule := range maskedMergeOpts {
   381  		if !asmCheck[asm] {
   382  			continue
   383  		}
   384  		maskedMergeOptsRules = append(maskedMergeOptsRules, rule)
   385  	}
   386  	slices.Sort(maskedMergeOptsRules)
   387  	for _, rule := range maskedMergeOptsRules {
   388  		buffer.WriteString(rule)
   389  	}
   390  
   391  	for _, data := range memOptData {
   392  		if err := ruleTemplates.ExecuteTemplate(buffer, data.tplName, data); err != nil {
   393  			panic(fmt.Errorf("failed to execute template %s for %s: %w", data.tplName, data.Asm, err))
   394  		}
   395  	}
   396  
   397  	return buffer
   398  }
   399  

View as plain text