1
2
3
4
5 package main
6
7 import (
8 "bytes"
9 "fmt"
10 "slices"
11 "strings"
12 "text/template"
13 )
14
// tplRuleData is the data handed to the rewrite-rule templates in
// ruleTemplates; the exported fields fill the {{.Field}} placeholders.
type tplRuleData struct {
	tplName        string // which template in ruleTemplates to execute (selection key, not referenced inside the templates)
	GoOp           string // generic (architecture-independent) op name
	GoType         string // Go SIMD type suffix appended to GoOp on the rule's LHS
	Args           string // LHS argument list, e.g. "x y", possibly prefixed with an immediate "[a]"
	Asm            string // machine op name on the rule's RHS
	ArgsOut        string // RHS argument list; may carry a constant immediate, e.g. "[7] x y"
	MaskInConvert  string // op converting a vector mask to a K-mask, e.g. "VPMOVVec32x8ToM"
	MaskOutConvert string // op converting a K-mask result back to a vector, e.g. "VPMOVMToVec32x8"
	ElementSize    int    // element width in bits (8, 16, 32, or 64)
	Size           int    // NOTE(review): not referenced anywhere in this file — confirm use elsewhere before removing
	ArgsLoadAddr   string // LHS args for load-merging rules, including the bound load l:(VMOVDQUload...)
	ArgsAddr       string // RHS args for load-merging rules ({sym}, offset, ptr, optional mask, mem)
	FeatCheck      string // CPU-feature condition inserted by the vregMemFeatCheck template
}
30
var (
	// ruleTemplates holds one named template per rule shape, each expanding a
	// tplRuleData into a single SSA rewrite rule line:
	//   pureVreg        - plain vreg-in/vreg-out rule
	//   maskIn          - extra mask input, converted to a K-mask on the RHS
	//   maskOut         - mask-typed result, converted back to a vector
	//   maskInMaskOut   - both of the above
	//   sftimm/masksftimm - fold a MOVQconst shift count into a const-immediate op
	//   vregMem(FeatCheck) - merge a VMOVDQUload into the op's memory form
	// NOTE(review): vregMemFeatCheck has no space before "=>" ("clobber(l)=>"),
	// unlike vregMem — presumably the rule parser tolerates this, but confirm
	// and normalize if not intentional.
	ruleTemplates = template.Must(template.New("simdRules").Parse(`
{{define "pureVreg"}}({{.GoOp}}{{.GoType}} {{.Args}}) => ({{.Asm}} {{.ArgsOut}})
{{end}}
{{define "maskIn"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => ({{.Asm}} {{.ArgsOut}} ({{.MaskInConvert}} <types.TypeMask> mask))
{{end}}
{{define "maskOut"}}({{.GoOp}}{{.GoType}} {{.Args}}) => ({{.MaskOutConvert}} ({{.Asm}} {{.ArgsOut}}))
{{end}}
{{define "maskInMaskOut"}}({{.GoOp}}{{.GoType}} {{.Args}} mask) => ({{.MaskOutConvert}} ({{.Asm}} {{.ArgsOut}} ({{.MaskInConvert}} <types.TypeMask> mask)))
{{end}}
{{define "sftimm"}}({{.Asm}} x (MOVQconst [c])) => ({{.Asm}}const [uint8(c)] x)
{{end}}
{{define "masksftimm"}}({{.Asm}} x (MOVQconst [c]) mask) => ({{.Asm}}const [uint8(c)] x mask)
{{end}}
{{define "vregMem"}}({{.Asm}} {{.ArgsLoadAddr}}) && canMergeLoad(v, l) && clobber(l) => ({{.Asm}}load {{.ArgsAddr}})
{{end}}
{{define "vregMemFeatCheck"}}({{.Asm}} {{.ArgsLoadAddr}}) && {{.FeatCheck}} && canMergeLoad(v, l) && clobber(l)=> ({{.Asm}}load {{.ArgsAddr}})
{{end}}
`))
)
51
52 func (d tplRuleData) MaskOptimization(asmCheck map[string]bool) string {
53 asmNoMask := d.Asm
54 if i := strings.Index(asmNoMask, "Masked"); i == -1 {
55 return ""
56 }
57 asmNoMask = strings.ReplaceAll(asmNoMask, "Masked", "")
58 if asmCheck[asmNoMask] == false {
59 return ""
60 }
61
62 for _, nope := range []string{"VMOVDQU", "VPCOMPRESS", "VCOMPRESS", "VPEXPAND", "VEXPAND", "VPBLENDM", "VMOVUP"} {
63 if strings.HasPrefix(asmNoMask, nope) {
64 return ""
65 }
66 }
67
68 size := asmNoMask[len(asmNoMask)-3:]
69 if strings.HasSuffix(asmNoMask, "const") {
70 sufLen := len("128const")
71 size = asmNoMask[len(asmNoMask)-sufLen:][:3]
72 }
73 switch size {
74 case "128", "256", "512":
75 default:
76 panic("Unexpected operation size on " + d.Asm)
77 }
78
79 switch d.ElementSize {
80 case 8, 16, 32, 64:
81 default:
82 panic(fmt.Errorf("Unexpected operation width %d on %v", d.ElementSize, d.Asm))
83 }
84
85 return fmt.Sprintf("(VMOVDQU%dMasked%s (%s %s) mask) => (%s %s mask)\n", d.ElementSize, size, asmNoMask, d.Args, d.Asm, d.Args)
86 }
87
88
// tmplOrder fixes the relative output order of rules generated from the same
// Go op/type/args; compareTplRuleData uses it as the final tie-breaker so the
// emitted rules file is deterministic. Lower values sort first.
var tmplOrder = map[string]int{
	"masksftimm":    0,
	"sftimm":        1,
	"maskInMaskOut": 2,
	"maskOut":       3,
	"maskIn":        4,
	"pureVreg":      5,
	"vregMem":       6,
}
98
99 func compareTplRuleData(x, y tplRuleData) int {
100 if c := compareNatural(x.GoOp, y.GoOp); c != 0 {
101 return c
102 }
103 if c := compareNatural(x.GoType, y.GoType); c != 0 {
104 return c
105 }
106 if c := compareNatural(x.Args, y.Args); c != 0 {
107 return c
108 }
109 if x.tplName == y.tplName {
110 return 0
111 }
112 xo, xok := tmplOrder[x.tplName]
113 yo, yok := tmplOrder[y.tplName]
114 if !xok {
115 panic(fmt.Errorf("Unexpected template name %s, please add to tmplOrder", x.tplName))
116 }
117 if !yok {
118 panic(fmt.Errorf("Unexpected template name %s, please add to tmplOrder", y.tplName))
119 }
120 return xo - yo
121 }
122
123
124
// writeSIMDRules renders the SSA rewrite rules for all SIMD operations into a
// buffer: one lowering rule per op, plus derived optimization rules (shift
// immediates, masked-move elision, masked merging, and load merging).
func writeSIMDRules(ops []Operation) *bytes.Buffer {
	buffer := new(bytes.Buffer)
	buffer.WriteString(generatedHeader + "\n")

	// maskedMergeOpts maps an unmasked machine op name to its merging rule;
	// kept in a map so duplicates can be detected and output deduplicated.
	maskedMergeOpts := make(map[string]string)
	// s2n maps an element width in bits to its x86 size letter.
	s2n := map[int]string{8: "B", 16: "W", 32: "D", 64: "Q"}
	asmCheck := map[string]bool{}    // machine ops actually emitted; gates derived rules
	sftimmCheck := map[string]bool{} // dedup for shift-immediate rules per machine op
	var allData []tplRuleData        // lowering rules, rendered after sorting
	var optData []tplRuleData        // ops skipped for generic lowering; still eligible for mask optimization
	var memOptData []tplRuleData     // load-merging rules, rendered last
	memOpSeen := make(map[string]bool)

	for _, opr := range ops {
		opInShape, opOutShape, maskType, immType, gOp := opr.shape()
		asm := machineOpName(maskType, gOp)
		// vregInCnt is the number of vector-register inputs, excluding a
		// trailing mask input if present.
		vregInCnt := len(gOp.In)
		if maskType == OneMask {
			vregInCnt--
		}

		data := tplRuleData{
			GoOp: gOp.Go,
			Asm:  asm,
		}

		// Pick the argument names by arity; LHS and RHS start out identical.
		if vregInCnt == 1 {
			data.Args = "x"
			data.ArgsOut = data.Args
		} else if vregInCnt == 2 {
			data.Args = "x y"
			data.ArgsOut = data.Args
		} else if vregInCnt == 3 {
			data.Args = "x y z"
			data.ArgsOut = data.Args
		} else {
			panic(fmt.Errorf("simdgen does not support more than 3 vreg in inputs"))
		}
		// Immediates: a fixed constant appears only on the RHS; a variable
		// immediate "[a]" appears on both sides; ConstVarImm adds a fixed
		// offset to the variable immediate on the RHS.
		if immType == ConstImm {
			data.ArgsOut = fmt.Sprintf("[%s] %s", *opr.In[0].Const, data.ArgsOut)
		} else if immType == VarImm {
			data.Args = fmt.Sprintf("[a] %s", data.Args)
			data.ArgsOut = fmt.Sprintf("[a] %s", data.ArgsOut)
		} else if immType == ConstVarImm {
			data.Args = fmt.Sprintf("[a] %s", data.Args)
			data.ArgsOut = fmt.Sprintf("[a+%s] %s", *opr.In[0].Const, data.ArgsOut)
		}

		// goType picks which input's Go type names the generic op; some
		// operand orders take it from the second input instead of the first.
		goType := func(op Operation) string {
			if op.OperandOrder != nil {
				switch *op.OperandOrder {
				case "21Type1", "231Type1":
					return *op.In[1].Go
				}
			}
			return *op.In[0].Go
		}
		var tplName string

		// Select the rule template from the op's input/output shape.
		if opOutShape == OneVregOut || opOutShape == OneVregOutAtIn || gOp.Out[0].OverwriteClass != nil {
			switch opInShape {
			case OneImmIn:
				tplName = "pureVreg"
				data.GoType = goType(gOp)
			case PureVregIn:
				tplName = "pureVreg"
				data.GoType = goType(gOp)
			case OneKmaskImmIn:
				fallthrough
			case OneKmaskIn:
				tplName = "maskIn"
				data.GoType = goType(gOp)
				// The mask is always the last input.
				rearIdx := len(gOp.In) - 1
				width := *gOp.In[rearIdx].ElemBits
				data.MaskInConvert = fmt.Sprintf("VPMOVVec%dx%dToM", width, *gOp.In[rearIdx].Lanes)
				data.ElementSize = width
			case PureKmaskIn:
				panic(fmt.Errorf("simdgen does not support pure k mask instructions, they should be generated by compiler optimizations"))
			}
		} else if opOutShape == OneGregOut {
			tplName = "pureVreg"
			data.GoType = goType(gOp)
		} else {
			// Mask-typed output: the K-mask result must be converted back to
			// a vector on the RHS.
			data.MaskOutConvert = fmt.Sprintf("VPMOVMToVec%dx%d", *gOp.Out[0].ElemBits, *gOp.In[0].Lanes)
			switch opInShape {
			case OneImmIn:
				fallthrough
			case PureVregIn:
				tplName = "maskOut"
				data.GoType = goType(gOp)
			case OneKmaskImmIn:
				fallthrough
			case OneKmaskIn:
				tplName = "maskInMaskOut"
				data.GoType = goType(gOp)
				rearIdx := len(gOp.In) - 1
				data.MaskInConvert = fmt.Sprintf("VPMOVVec%dx%dToM", *gOp.In[rearIdx].ElemBits, *gOp.In[rearIdx].Lanes)
			case PureKmaskIn:
				panic(fmt.Errorf("simdgen does not support pure k mask instructions, they should be generated by compiler optimizations"))
			}
		}

		// Shift-immediate special lowering: emit one extra rule (per machine
		// op) folding a MOVQconst shift count into the const form.
		if gOp.SpecialLower != nil {
			if *gOp.SpecialLower == "sftimm" {
				if !sftimmCheck[data.Asm] {
					sftimmCheck[data.Asm] = true
					sftImmData := data
					if tplName == "maskIn" {
						sftImmData.tplName = "masksftimm"
					} else {
						sftImmData.tplName = "sftimm"
					}
					allData = append(allData, sftImmData)
					asmCheck[sftImmData.Asm+"const"] = true
				}
			} else {
				// NOTE(review): "unknwon" typo is in the emitted panic text;
				// left as-is here since this pass changes no behavior.
				panic("simdgen sees unknwon special lower " + *gOp.SpecialLower + ", maybe implement it?")
			}
		}
		// Load-merging ("vbcst") rules: allow folding a VMOVDQUload of the
		// last vreg operand directly into the op's memory form.
		if gOp.MemFeatures != nil && *gOp.MemFeatures == "vbcst" {
			// Skip ops with scalar-like operands, and only emit one
			// load-merging rule per machine op.
			selected := true
			for _, a := range gOp.In {
				if a.TreatLikeAScalarOfSize != nil {
					selected = false
					break
				}
			}
			if _, ok := memOpSeen[data.Asm]; ok {
				selected = false
			}
			if selected {
				memOpSeen[data.Asm] = true
				lastVreg := gOp.In[vregInCnt-1]

				if lastVreg.Class != "vreg" {
					panic(fmt.Errorf("simdgen expects vbcst replaced operand to be a vreg, but %v found", lastVreg))
				}
				memOpData := data

				// Drop the last argument letter (the operand replaced by the
				// load); relies on args being single-letter, space-separated.
				origArgs := data.Args[:len(data.Args)-1]

				immArg := ""
				immArgCombineOff := " [off] "
				if immType != NoImm && immType != InvalidImm {
					// Strip the leading "[a] " immediate from the args; it is
					// re-introduced as "[c]" and combined with the load offset.
					_, after, found := strings.Cut(origArgs, "]")
					if found {
						origArgs = after
					}
					immArg = "[c] "
					immArgCombineOff = " [makeValAndOff(int32(uint8(c)),off)] "
				}
				memOpData.ArgsLoadAddr = immArg + origArgs + fmt.Sprintf("l:(VMOVDQUload%d {sym} [off] ptr mem)", *lastVreg.Bits)

				memOpData.ArgsAddr = "{sym}" + immArgCombineOff + origArgs + "ptr"
				if maskType == OneMask {
					memOpData.ArgsAddr += " mask"
					memOpData.ArgsLoadAddr += " mask"
				}
				memOpData.ArgsAddr += " mem"
				// If the memory form needs a specific CPU feature, attach the
				// corresponding rule condition.
				if gOp.MemFeaturesData != nil {
					_, feat2 := getVbcstData(*gOp.MemFeaturesData)
					knownFeatChecks := map[string]string{
						"AVX":    "v.Block.CPUfeatures.hasFeature(CPUavx)",
						"AVX2":   "v.Block.CPUfeatures.hasFeature(CPUavx2)",
						"AVX512": "v.Block.CPUfeatures.hasFeature(CPUavx512)",
					}
					memOpData.FeatCheck = knownFeatChecks[feat2]
					memOpData.tplName = "vregMemFeatCheck"
				} else {
					memOpData.tplName = "vregMem"
				}
				memOptData = append(memOptData, memOpData)
				asmCheck[memOpData.Asm+"load"] = true
			}
		}

		// Masked-merging rules: rewrite a blend of dst with a masked op's
		// result into the op's Merging form.
		if gOp.hasMaskedMerging(maskType, opOutShape) {
			maskElem := gOp.In[len(gOp.In)-1]
			if maskElem.Bits == nil {
				panic("mask has no bits")
			}
			if maskElem.ElemBits == nil {
				panic("mask has no elemBits")
			}
			if maskElem.Lanes == nil {
				panic("mask has no lanes")
			}
			switch *maskElem.Bits {
			case 128, 256:
				// 128/256-bit vectors blend via VPBLENDVB; the rule is gated
				// on AVX-512 since the Merging form needs a K-mask.
				noMaskName := machineOpName(NoMask, gOp)
				ruleExisting, ok := maskedMergeOpts[noMaskName]
				rule := fmt.Sprintf("(VPBLENDVB%d dst (%s %s) mask) && v.Block.CPUfeatures.hasFeature(CPUavx512) => (%sMerging dst %s (VPMOVVec%dx%dToM <types.TypeMask> mask))\n",
					*maskElem.Bits, noMaskName, data.Args, data.Asm, data.Args, *maskElem.ElemBits, *maskElem.Lanes)
				if ok && ruleExisting != rule {
					panic(fmt.Sprintf("multiple masked merge rules for one op:\n%s\n%s\n", ruleExisting, rule))
				} else {
					maskedMergeOpts[noMaskName] = rule
				}
			case 512:
				// 512-bit vectors already blend with a K-mask (VPBLENDM*).
				noMaskName := machineOpName(NoMask, gOp)
				ruleExisting, ok := maskedMergeOpts[noMaskName]
				rule := fmt.Sprintf("(VPBLENDM%sMasked%d dst (%s %s) mask) => (%sMerging dst %s mask)\n",
					s2n[*maskElem.ElemBits], *maskElem.Bits, noMaskName, data.Args, data.Asm, data.Args)
				if ok && ruleExisting != rule {
					panic(fmt.Sprintf("multiple masked merge rules for one op:\n%s\n%s\n", ruleExisting, rule))
				} else {
					maskedMergeOpts[noMaskName] = rule
				}
			}
		}

		// When LHS and RHS args match exactly, the rule can use "..."
		// shorthand.
		if tplName == "pureVreg" && data.Args == data.ArgsOut {
			data.Args = "..."
			data.ArgsOut = "..."
		}
		data.tplName = tplName
		// Ops without generic lowering still feed the mask-optimization pass.
		if opr.NoGenericOps != nil && *opr.NoGenericOps == "true" ||
			opr.SkipMaskedMethod() {
			optData = append(optData, data)
			continue
		}
		allData = append(allData, data)
		asmCheck[data.Asm] = true
	}

	// Sort for deterministic output, then render the lowering rules.
	slices.SortFunc(allData, compareTplRuleData)

	for _, data := range allData {
		if err := ruleTemplates.ExecuteTemplate(buffer, data.tplName, data); err != nil {
			panic(fmt.Errorf("failed to execute template %s for %s: %w", data.tplName, data.GoOp+data.GoType, err))
		}
	}

	// Masked-move elision rules, deduplicated ("" from MaskOptimization is
	// also deduplicated here, so at most one empty write occurs).
	seen := make(map[string]bool)

	for _, data := range optData {
		if data.tplName == "maskIn" {
			rule := data.MaskOptimization(asmCheck)
			if seen[rule] {
				continue
			}
			seen[rule] = true
			buffer.WriteString(rule)
		}
	}

	// Masked-merging rules, only for ops actually emitted, in sorted order.
	maskedMergeOptsRules := []string{}
	for asm, rule := range maskedMergeOpts {
		if !asmCheck[asm] {
			continue
		}
		maskedMergeOptsRules = append(maskedMergeOptsRules, rule)
	}
	slices.Sort(maskedMergeOptsRules)
	for _, rule := range maskedMergeOptsRules {
		buffer.WriteString(rule)
	}

	// Finally the load-merging rules.
	for _, data := range memOptData {
		if err := ruleTemplates.ExecuteTemplate(buffer, data.tplName, data); err != nil {
			panic(fmt.Errorf("failed to execute template %s for %s: %w", data.tplName, data.Asm, err))
		}
	}

	return buffer
}
399
View as plain text