1
2
3
4
5 package main
6
7
8
9
10 import (
11 "bufio"
12 "bytes"
13 "flag"
14 "fmt"
15 "go/format"
16 "io"
17 "os"
18 "strings"
19 "text/template"
20 )
21
// resultTypeFunc maps an input shape (element type name t, element width w,
// element count c) to the output shape (ot, ow, oc) that a template's
// O* fields describe.
type resultTypeFunc func(t string, w, c int) (ot string, ow int, oc int)
23
24
// shapes describes a set of vector shapes to generate code for: every
// vector width in vecs is combined with every element width listed for
// each element kind.
type shapes struct {
	vecs   []int // vector widths, in bits
	ints   []int // signed integer element widths, in bits
	uints  []int // unsigned integer element widths, in bits
	floats []int // floating-point element widths, in bits
	// output, if non-nil, maps each input shape to the output shape
	// used to fill in the O* template fields.
	output resultTypeFunc
}
32
33
// shapeAndTemplate pairs a template with the set of shapes it is
// expanded for.
type shapeAndTemplate struct {
	s *shapes            // shapes to expand the template for
	t *template.Template // parsed template
}
38
39 func (sat shapeAndTemplate) target(outType string, width int) shapeAndTemplate {
40 newSat := sat
41 newShape := *sat.s
42 newShape.output = func(t string, w, c int) (ot string, ow int, oc int) {
43 oc = c
44 if width*c > 512 {
45 oc = 512 / width
46 } else if width*c < 128 {
47 oc = 128 / width
48 }
49 return outType, width, oc
50 }
51 newSat.s = &newShape
52 return newSat
53 }
54
55 func (sat shapeAndTemplate) targetFixed(outType string, width, count int) shapeAndTemplate {
56 newSat := sat
57 newShape := *sat.s
58 newShape.output = func(t string, w, c int) (ot string, ow int, oc int) {
59 return outType, width, count
60 }
61 newSat.s = &newShape
62 return newSat
63 }
64
65 func (s *shapes) forAllShapes(f func(seq int, t, upperT string, w, c int, out io.Writer), out io.Writer) {
66 vecs := s.vecs
67 ints := s.ints
68 uints := s.uints
69 floats := s.floats
70 seq := 0
71 for _, v := range vecs {
72 for _, w := range ints {
73 c := v / w
74 f(seq, "int", "Int", w, c, out)
75 seq++
76 }
77 for _, w := range uints {
78 c := v / w
79 f(seq, "uint", "Uint", w, c, out)
80 seq++
81 }
82 for _, w := range floats {
83 c := v / w
84 f(seq, "float", "Float", w, c, out)
85 seq++
86 }
87 }
88 }
89
// allShapes is every vector width (128, 256, 512 bits) combined with every
// signed, unsigned and floating-point element width. It is the shape set
// used by templateOf.
var allShapes = &shapes{
	vecs:   []int{128, 256, 512},
	ints:   []int{8, 16, 32, 64},
	uints:  []int{8, 16, 32, 64},
	floats: []int{32, 64},
}

// intShapes is all vector widths with signed integer elements only.
var intShapes = &shapes{
	vecs: []int{128, 256, 512},
	ints: []int{8, 16, 32, 64},
}

// uintShapes is all vector widths with unsigned integer elements only.
var uintShapes = &shapes{
	vecs:  []int{128, 256, 512},
	uints: []int{8, 16, 32, 64},
}

// floatShapes is all vector widths with floating-point elements only.
var floatShapes = &shapes{
	vecs:   []int{128, 256, 512},
	floats: []int{32, 64},
}

// integerShapes is all vector widths with signed and unsigned integer
// elements (no floating point).
var integerShapes = &shapes{
	vecs:  []int{128, 256, 512},
	ints:  []int{8, 16, 32, 64},
	uints: []int{8, 16, 32, 64},
}

// avx512Shapes is the 512-bit vectors of every element type.
var avx512Shapes = &shapes{
	vecs:   []int{512},
	ints:   []int{8, 16, 32, 64},
	uints:  []int{8, 16, 32, 64},
	floats: []int{32, 64},
}

// avx2Shapes is the 128- and 256-bit vectors of every element type.
var avx2Shapes = &shapes{
	vecs:   []int{128, 256},
	ints:   []int{8, 16, 32, 64},
	uints:  []int{8, 16, 32, 64},
	floats: []int{32, 64},
}

// avx2MaskedLoadShapes is the 128- and 256-bit vectors with 32- and 64-bit
// elements; it is the shape set of avx2MaskedLoadSlicePartTemplate.
var avx2MaskedLoadShapes = &shapes{
	vecs:   []int{128, 256},
	ints:   []int{32, 64},
	uints:  []int{32, 64},
	floats: []int{32, 64},
}

// avx2SmallLoadPunShapes is the 128- and 256-bit vectors with 8- and 16-bit
// unsigned elements; it is the shape set of avx2SmallLoadSlicePartTemplate,
// which reinterprets the unsigned slices as signed slices of the same width.
var avx2SmallLoadPunShapes = &shapes{

	vecs:  []int{128, 256},
	uints: []int{8, 16},
}

// unaryFlaky is the floating-point shapes that unaryFlakyTemplate is
// expanded for.
var unaryFlaky = &shapes{
	vecs:   []int{128, 256, 512},
	floats: []int{32, 64},
}

// ternaryFlaky is the 32-bit floating-point shapes that
// ternaryFlakyTemplate is expanded for.
var ternaryFlaky = &shapes{
	vecs:   []int{128, 256, 512},
	floats: []int{32},
}

// avx2SignedComparisons is the 128- and 256-bit signed integer shapes that
// avx2SignedComparisonsTemplate is expanded for.
var avx2SignedComparisons = &shapes{
	vecs: []int{128, 256},
	ints: []int{8, 16, 32, 64},
}

// avx2UnsignedComparisons is the 128- and 256-bit unsigned integer shapes
// that avx2UnsignedComparisonsTemplate is expanded for.
var avx2UnsignedComparisons = &shapes{
	vecs:  []int{128, 256},
	uints: []int{8, 16, 32, 64},
}
164
// templateData is the value passed to each template execution. It is
// filled in by oneTemplate for one input shape (and its output shape).
type templateData struct {
	VType  string // vector type name: Base followed by WxC, e.g. "Int32x4"
	AOrAn  string // "an" if the base type name starts with a vowel, else "a"
	EWidth int    // element width, in bits
	Vwidth int    // vector width, in bits (EWidth * Count)
	Count  int    // number of elements
	WxC    string // "<EWidth>x<Count>", e.g. "32x4"
	BxC    string // the same vector size as 8-bit elements, e.g. "8x16"
	Base   string // capitalized base type name, e.g. "Int", "Uint", "Float"
	Etype  string // element type name, e.g. "int32"
	OxFF   string // hexadecimal literal with the low Count bits set

	OVType string // output vector type name, e.g. "Int64x4"
	OEtype string // output element type name, e.g. "int64"
	OEType string // capitalized output element type name, e.g. "Int64"
	OCount int    // number of output elements
}
182
183 func (t templateData) As128BitVec() string {
184 return fmt.Sprintf("%s%dx%d", t.Base, t.EWidth, 128/t.EWidth)
185 }
186
187 func oneTemplate(t *template.Template, baseType string, width, count int, out io.Writer, rtf resultTypeFunc) {
188 b := width * count
189 if b < 128 || b > 512 {
190 return
191 }
192
193 ot, ow, oc := baseType, width, count
194 if rtf != nil {
195 ot, ow, oc = rtf(ot, ow, oc)
196 if ow*oc > 512 || ow*oc < 128 || ow < 8 || ow > 64 {
197 return
198 }
199
200 if ot == "float" && ow < 32 {
201 return
202 }
203 }
204 ovType := fmt.Sprintf("%s%dx%d", strings.ToUpper(ot[:1])+ot[1:], ow, oc)
205 oeType := fmt.Sprintf("%s%d", ot, ow)
206 oEType := fmt.Sprintf("%s%d", strings.ToUpper(ot[:1])+ot[1:], ow)
207
208 wxc := fmt.Sprintf("%dx%d", width, count)
209 BaseType := strings.ToUpper(baseType[:1]) + baseType[1:]
210 vType := fmt.Sprintf("%s%s", BaseType, wxc)
211 eType := fmt.Sprintf("%s%d", baseType, width)
212
213 bxc := fmt.Sprintf("%dx%d", 8, count*(width/8))
214 aOrAn := "a"
215 if strings.Contains("aeiou", baseType[:1]) {
216 aOrAn = "an"
217 }
218 oxFF := fmt.Sprintf("0x%x", uint64((1<<count)-1))
219 t.Execute(out, templateData{
220 VType: vType,
221 AOrAn: aOrAn,
222 EWidth: width,
223 Vwidth: b,
224 Count: count,
225 WxC: wxc,
226 BxC: bxc,
227 Base: BaseType,
228 Etype: eType,
229 OxFF: oxFF,
230 OVType: ovType,
231 OEtype: oeType,
232 OCount: oc,
233 OEType: oEType,
234 })
235 }
236
237
238
239 func (sat shapeAndTemplate) forTemplates(out io.Writer) {
240 t, s := sat.t, sat.s
241 vecs := s.vecs
242 ints := s.ints
243 uints := s.uints
244 floats := s.floats
245 for _, v := range vecs {
246 for _, w := range ints {
247 c := v / w
248 oneTemplate(t, "int", w, c, out, sat.s.output)
249 }
250 for _, w := range uints {
251 c := v / w
252 oneTemplate(t, "uint", w, c, out, sat.s.output)
253 }
254 for _, w := range floats {
255 c := v / w
256 oneTemplate(t, "float", w, c, out, sat.s.output)
257 }
258 }
259 }
260
261 func prologue(s string, out io.Writer) {
262 fmt.Fprintf(out,
263 `// Code generated by '%s'; DO NOT EDIT.
264
265 //go:build goexperiment.simd
266
267 package archsimd
268
269 `, s)
270 }
271
272 func ssaPrologue(s string, out io.Writer) {
273 fmt.Fprintf(out,
274 `// Code generated by '%s'; DO NOT EDIT.
275
276 package ssa
277
278 `, s)
279 }
280
281 func unsafePrologue(s string, out io.Writer) {
282 fmt.Fprintf(out,
283 `// Code generated by '%s'; DO NOT EDIT.
284
285 //go:build goexperiment.simd
286
287 package archsimd
288
289 import "unsafe"
290
291 `, s)
292 }
293
294 func testPrologue(t, s string, out io.Writer) {
295 fmt.Fprintf(out,
296 `// Code generated by '%s'; DO NOT EDIT.
297
298 //go:build goexperiment.simd && amd64
299
300 // This file contains functions testing %s.
301 // Each function in this file is specialized for a
302 // particular simd type <BaseType><Width>x<Count>.
303
304 package simd_test
305
306 import (
307 "simd/archsimd"
308 "testing"
309 )
310
311 `, s, t)
312 }
313
314 func curryTestPrologue(t string) func(s string, out io.Writer) {
315 return func(s string, out io.Writer) {
316 testPrologue(t, s, out)
317 }
318 }
319
320 func templateOf(name, temp string) shapeAndTemplate {
321 return shapeAndTemplate{s: allShapes,
322 t: template.Must(template.New(name).Parse(temp))}
323 }
324
325 func shapedTemplateOf(s *shapes, name, temp string) shapeAndTemplate {
326 return shapeAndTemplate{s: s,
327 t: template.Must(template.New(name).Parse(temp))}
328 }
329
// sliceTemplate generates, for every shape in allShapes, the function
// Load<VType>Slice and the method StoreSlice, which convert the slice to a
// pointer to an array of Count elements and call Load<VType> / Store.
var sliceTemplate = templateOf("slice", `
// Load{{.VType}}Slice loads {{.AOrAn}} {{.VType}} from a slice of at least {{.Count}} {{.Etype}}s.
func Load{{.VType}}Slice(s []{{.Etype}}) {{.VType}} {
return Load{{.VType}}((*[{{.Count}}]{{.Etype}})(s))
}

// StoreSlice stores x into a slice of at least {{.Count}} {{.Etype}}s.
func (x {{.VType}}) StoreSlice(s []{{.Etype}}) {
x.Store((*[{{.Count}}]{{.Etype}})(s))
}
`)
341
// unaryTemplate generates, for every shape in allShapes, the test helper
// test<VType>Unary, which compares a unary vector method against a
// slice-based reference function with zero flakiness tolerance.
var unaryTemplate = templateOf("unary_helpers", `
// test{{.VType}}Unary tests the simd unary method f against the expected behavior generated by want
func test{{.VType}}Unary(t *testing.T, f func(_ archsimd.{{.VType}}) archsimd.{{.VType}}, want func(_ []{{.Etype}}) []{{.Etype}}) {
n := {{.Count}}
t.Helper()
forSlice(t, {{.Etype}}s, n, func(x []{{.Etype}}) bool {
t.Helper()
a := archsimd.Load{{.VType}}Slice(x)
g := make([]{{.Etype}}, n)
f(a).StoreSlice(g)
w := want(x)
return checkSlicesLogInput(t, g, w, 0.0, func() {t.Helper(); t.Logf("x=%v", x)})
})
}
`)
357
// unaryFlakyTemplate generates, for the floating-point shapes in
// unaryFlaky, the test helper test<VType>UnaryFlaky. It is like the helper
// from unaryTemplate but takes a flakiness parameter that is passed on to
// checkSlicesLogInput.
var unaryFlakyTemplate = shapedTemplateOf(unaryFlaky, "unary_flaky_helpers", `
// test{{.VType}}UnaryFlaky tests the simd unary method f against the expected behavior generated by want,
// but using a flakiness parameter because we haven't exactly figured out how simd floating point works
func test{{.VType}}UnaryFlaky(t *testing.T, f func(x archsimd.{{.VType}}) archsimd.{{.VType}}, want func(x []{{.Etype}}) []{{.Etype}}, flakiness float64) {
n := {{.Count}}
t.Helper()
forSlice(t, {{.Etype}}s, n, func(x []{{.Etype}}) bool {
t.Helper()
a := archsimd.Load{{.VType}}Slice(x)
g := make([]{{.Etype}}, n)
f(a).StoreSlice(g)
w := want(x)
return checkSlicesLogInput(t, g, w, flakiness, func() {t.Helper(); t.Logf("x=%v", x)})
})
}
`)
374
// convertTemplate generates the test helper test<VType>ConvertTo<OEType>
// for conversions from the input vector type to the output vector type
// OVType. On its own it has no output function, so the output shape equals
// the input shape; the target method is used to retarget it to a specific
// output element type.
var convertTemplate = templateOf("convert_helpers", `
// test{{.VType}}ConvertTo{{.OEType}} tests the simd conversion method f against the expected behavior generated by want.
// This is for count-preserving conversions, so if there is a change in size, then there is a change in vector width,
// (extended to at least 128 bits, or truncated to at most 512 bits).
func test{{.VType}}ConvertTo{{.OEType}}(t *testing.T, f func(x archsimd.{{.VType}}) archsimd.{{.OVType}}, want func(x []{{.Etype}}) []{{.OEtype}}) {
n := {{.Count}}
t.Helper()
forSlice(t, {{.Etype}}s, n, func(x []{{.Etype}}) bool {
t.Helper()
a := archsimd.Load{{.VType}}Slice(x)
g := make([]{{.OEtype}}, {{.OCount}})
f(a).StoreSlice(g)
w := want(x)
return checkSlicesLogInput(t, g, w, 0.0, func() {t.Helper(); t.Logf("x=%v", x)})
})
}
`)
392
var (
	// Copies of convertTemplate retargeted to each output element type.
	// target keeps the input element count unless the resulting vector
	// would be wider than 512 bits or narrower than 128 bits, in which
	// case the output count is adjusted to fit that range.
	unaryToInt8    = convertTemplate.target("int", 8)
	unaryToUint8   = convertTemplate.target("uint", 8)
	unaryToInt16   = convertTemplate.target("int", 16)
	unaryToUint16  = convertTemplate.target("uint", 16)
	unaryToInt32   = convertTemplate.target("int", 32)
	unaryToUint32  = convertTemplate.target("uint", 32)
	unaryToInt64   = convertTemplate.target("int", 64)
	unaryToUint64  = convertTemplate.target("uint", 64)
	unaryToFloat32 = convertTemplate.target("float", 32)
	unaryToFloat64 = convertTemplate.target("float", 64)
)
408
// convertLoTemplate generates, for the integer shapes in integerShapes,
// the test helper test<VType>ConvertLoTo<OVType> for conversions whose
// result holds OCount elements. The targetFixed method is used to retarget
// it to a specific output vector type.
var convertLoTemplate = shapedTemplateOf(integerShapes, "convert_lo_helpers", `
// test{{.VType}}ConvertLoTo{{.OVType}} tests the simd conversion method f against the expected behavior generated by want.
// This converts only the low {{.OCount}} elements.
func test{{.VType}}ConvertLoTo{{.OVType}}(t *testing.T, f func(x archsimd.{{.VType}}) archsimd.{{.OVType}}, want func(x []{{.Etype}}) []{{.OEtype}}) {
n := {{.Count}}
t.Helper()
forSlice(t, {{.Etype}}s, n, func(x []{{.Etype}}) bool {
t.Helper()
a := archsimd.Load{{.VType}}Slice(x)
g := make([]{{.OEtype}}, {{.OCount}})
f(a).StoreSlice(g)
w := want(x)
return checkSlicesLogInput(t, g, w, 0.0, func() {t.Helper(); t.Logf("x=%v", x)})
})
}
`)
425
var (
	// Copies of convertLoTemplate retargeted to one fixed output vector
	// type each (element type, element width, element count), independent
	// of the input shape.
	unaryToInt64x2  = convertLoTemplate.targetFixed("int", 64, 2)
	unaryToInt64x4  = convertLoTemplate.targetFixed("int", 64, 4)
	unaryToUint64x2 = convertLoTemplate.targetFixed("uint", 64, 2)
	unaryToUint64x4 = convertLoTemplate.targetFixed("uint", 64, 4)
	unaryToInt32x4  = convertLoTemplate.targetFixed("int", 32, 4)
	unaryToInt32x8  = convertLoTemplate.targetFixed("int", 32, 8)
	unaryToUint32x4 = convertLoTemplate.targetFixed("uint", 32, 4)
	unaryToUint32x8 = convertLoTemplate.targetFixed("uint", 32, 8)
	unaryToInt16x8  = convertLoTemplate.targetFixed("int", 16, 8)
	unaryToUint16x8 = convertLoTemplate.targetFixed("uint", 16, 8)
)
443
// binaryTemplate generates, for every shape in allShapes, the test helper
// test<VType>Binary, which compares a two-operand vector method against a
// slice-based reference function with zero flakiness tolerance.
var binaryTemplate = templateOf("binary_helpers", `
// test{{.VType}}Binary tests the simd binary method f against the expected behavior generated by want
func test{{.VType}}Binary(t *testing.T, f func(_, _ archsimd.{{.VType}}) archsimd.{{.VType}}, want func(_, _ []{{.Etype}}) []{{.Etype}}) {
n := {{.Count}}
t.Helper()
forSlicePair(t, {{.Etype}}s, n, func(x, y []{{.Etype}}) bool {
t.Helper()
a := archsimd.Load{{.VType}}Slice(x)
b := archsimd.Load{{.VType}}Slice(y)
g := make([]{{.Etype}}, n)
f(a, b).StoreSlice(g)
w := want(x, y)
return checkSlicesLogInput(t, g, w, 0.0, func() {t.Helper(); t.Logf("x=%v", x); t.Logf("y=%v", y); })
})
}
`)
460
// ternaryTemplate generates, for every shape in allShapes, the test helper
// test<VType>Ternary, which compares a three-operand vector method against
// a slice-based reference function with zero flakiness tolerance.
var ternaryTemplate = templateOf("ternary_helpers", `
// test{{.VType}}Ternary tests the simd ternary method f against the expected behavior generated by want
func test{{.VType}}Ternary(t *testing.T, f func(_, _, _ archsimd.{{.VType}}) archsimd.{{.VType}}, want func(_, _, _ []{{.Etype}}) []{{.Etype}}) {
n := {{.Count}}
t.Helper()
forSliceTriple(t, {{.Etype}}s, n, func(x, y, z []{{.Etype}}) bool {
t.Helper()
a := archsimd.Load{{.VType}}Slice(x)
b := archsimd.Load{{.VType}}Slice(y)
c := archsimd.Load{{.VType}}Slice(z)
g := make([]{{.Etype}}, n)
f(a, b, c).StoreSlice(g)
w := want(x, y, z)
return checkSlicesLogInput(t, g, w, 0.0, func() {t.Helper(); t.Logf("x=%v", x); t.Logf("y=%v", y); t.Logf("z=%v", z); })
})
}
`)
478
479 var ternaryFlakyTemplate = shapedTemplateOf(ternaryFlaky, "ternary_helpers", `
480 // test{{.VType}}TernaryFlaky tests the simd ternary method f against the expected behavior generated by want,
481 // but using a flakiness parameter because we haven't exactly figured out how simd floating point works
482 func test{{.VType}}TernaryFlaky(t *testing.T, f func(x, y, z archsimd.{{.VType}}) archsimd.{{.VType}}, want func(x, y, z []{{.Etype}}) []{{.Etype}}, flakiness float64) {
483 n := {{.Count}}
484 t.Helper()
485 forSliceTriple(t, {{.Etype}}s, n, func(x, y, z []{{.Etype}}) bool {
486 t.Helper()
487 a := archsimd.Load{{.VType}}Slice(x)
488 b := archsimd.Load{{.VType}}Slice(y)
489 c := archsimd.Load{{.VType}}Slice(z)
490 g := make([]{{.Etype}}, n)
491 f(a, b, c).StoreSlice(g)
492 w := want(x, y, z)
493 return checkSlicesLogInput(t, g, w, flakiness, func() {t.Helper(); t.Logf("x=%v", x); t.Logf("y=%v", y); t.Logf("z=%v", z); })
494 })
495 }
496 `)
497
// compareTemplate generates, for every shape in allShapes, the test helper
// test<VType>Compare for two-operand methods that return a mask. The mask
// is converted to an integer vector, stored, and compared (via s64) with
// the []int64 produced by the reference function.
var compareTemplate = templateOf("compare_helpers", `
// test{{.VType}}Compare tests the simd comparison method f against the expected behavior generated by want
func test{{.VType}}Compare(t *testing.T, f func(_, _ archsimd.{{.VType}}) archsimd.Mask{{.WxC}}, want func(_, _ []{{.Etype}}) []int64) {
n := {{.Count}}
t.Helper()
forSlicePair(t, {{.Etype}}s, n, func(x, y []{{.Etype}}) bool {
t.Helper()
a := archsimd.Load{{.VType}}Slice(x)
b := archsimd.Load{{.VType}}Slice(y)
g := make([]int{{.EWidth}}, n)
f(a, b).ToInt{{.WxC}}().StoreSlice(g)
w := want(x, y)
return checkSlicesLogInput(t, s64(g), w, 0.0, func() {t.Helper(); t.Logf("x=%v", x); t.Logf("y=%v", y); })
})
}
`)
514
// compareUnaryTemplate generates, for the floating-point shapes in
// floatShapes, the test helper test<VType>UnaryCompare for one-operand
// methods that return a mask. The mask is converted to an integer vector,
// stored, and compared (via s64) with the []int64 produced by the
// reference function.
var compareUnaryTemplate = shapedTemplateOf(floatShapes, "compare_unary_helpers", `
// test{{.VType}}UnaryCompare tests the simd unary comparison method f against the expected behavior generated by want
func test{{.VType}}UnaryCompare(t *testing.T, f func(x archsimd.{{.VType}}) archsimd.Mask{{.WxC}}, want func(x []{{.Etype}}) []int64) {
n := {{.Count}}
t.Helper()
forSlice(t, {{.Etype}}s, n, func(x []{{.Etype}}) bool {
t.Helper()
a := archsimd.Load{{.VType}}Slice(x)
g := make([]int{{.EWidth}}, n)
f(a).ToInt{{.WxC}}().StoreSlice(g)
w := want(x)
return checkSlicesLogInput(t, s64(g), w, 0.0, func() {t.Helper(); t.Logf("x=%v", x)})
})
}
`)
530
531
// compareMaskedTemplate generates, for every shape in allShapes, the test
// helper test<VType>CompareMasked for two-operand comparison methods that
// also take a mask. The reference result is zeroed wherever the mask is
// false before it is compared with the method's result.
var compareMaskedTemplate = templateOf("comparemasked_helpers", `
// test{{.VType}}CompareMasked tests the simd masked comparison method f against the expected behavior generated by want
// The mask is applied to the output of want; anything not in the mask, is zeroed.
func test{{.VType}}CompareMasked(t *testing.T,
f func(_, _ archsimd.{{.VType}}, m archsimd.Mask{{.WxC}}) archsimd.Mask{{.WxC}},
want func(_, _ []{{.Etype}}) []int64) {
n := {{.Count}}
t.Helper()
forSlicePairMasked(t, {{.Etype}}s, n, func(x, y []{{.Etype}}, m []bool) bool {
t.Helper()
a := archsimd.Load{{.VType}}Slice(x)
b := archsimd.Load{{.VType}}Slice(y)
k := archsimd.LoadInt{{.WxC}}Slice(toVect[int{{.EWidth}}](m)).ToMask()
g := make([]int{{.EWidth}}, n)
f(a, b, k).ToInt{{.WxC}}().StoreSlice(g)
w := want(x, y)
for i := range m {
if !m[i] {
w[i] = 0
}
}
return checkSlicesLogInput(t, s64(g), w, 0.0, func() {t.Helper(); t.Logf("x=%v", x); t.Logf("y=%v", y); t.Logf("m=%v", m); })
})
}
`)
557
// avx512MaskedLoadSlicePartTemplate generates, for the 512-bit shapes in
// avx512Shapes, Load<VType>SlicePart and StoreSlicePart. For slices shorter
// than Count elements they build a mask from the OxFF bit pattern shifted
// right by the number of missing elements and use a masked load or store.
var avx512MaskedLoadSlicePartTemplate = shapedTemplateOf(avx512Shapes, "avx 512 load slice part", `
// Load{{.VType}}SlicePart loads a {{.VType}} from the slice s.
// If s has fewer than {{.Count}} elements, the remaining elements of the vector are filled with zeroes.
// If s has {{.Count}} or more elements, the function is equivalent to Load{{.VType}}Slice.
func Load{{.VType}}SlicePart(s []{{.Etype}}) {{.VType}} {
l := len(s)
if l >= {{.Count}} {
return Load{{.VType}}Slice(s)
}
if l == 0 {
var x {{.VType}}
return x
}
mask := Mask{{.WxC}}FromBits({{.OxFF}} >> ({{.Count}} - l))
return LoadMasked{{.VType}}(pa{{.VType}}(s), mask)
}

// StoreSlicePart stores the {{.Count}} elements of x into the slice s.
// It stores as many elements as will fit in s.
// If s has {{.Count}} or more elements, the method is equivalent to x.StoreSlice.
func (x {{.VType}}) StoreSlicePart(s []{{.Etype}}) {
l := len(s)
if l >= {{.Count}} {
x.StoreSlice(s)
return
}
if l == 0 {
return
}
mask := Mask{{.WxC}}FromBits({{.OxFF}} >> ({{.Count}} - l))
x.StoreMasked(pa{{.VType}}(s), mask)
}
`)
591
// avx2MaskedLoadSlicePartTemplate generates, for the shapes in
// avx2MaskedLoadShapes, Load<VType>SlicePart and StoreSlicePart. For slices
// shorter than Count elements they take the mask from a slice of the
// vecMask<EWidth> table (defined outside this generator) and use a masked
// load or store.
var avx2MaskedLoadSlicePartTemplate = shapedTemplateOf(avx2MaskedLoadShapes, "avx 2 load slice part", `
// Load{{.VType}}SlicePart loads a {{.VType}} from the slice s.
// If s has fewer than {{.Count}} elements, the remaining elements of the vector are filled with zeroes.
// If s has {{.Count}} or more elements, the function is equivalent to Load{{.VType}}Slice.
func Load{{.VType}}SlicePart(s []{{.Etype}}) {{.VType}} {
l := len(s)
if l >= {{.Count}} {
return Load{{.VType}}Slice(s)
}
if l == 0 {
var x {{.VType}}
return x
}
mask := vecMask{{.EWidth}}[len(vecMask{{.EWidth}})/2-l:]
return LoadMasked{{.VType}}(pa{{.VType}}(s), LoadInt{{.WxC}}Slice(mask).asMask())
}

// StoreSlicePart stores the {{.Count}} elements of x into the slice s.
// It stores as many elements as will fit in s.
// If s has {{.Count}} or more elements, the method is equivalent to x.StoreSlice.
func (x {{.VType}}) StoreSlicePart(s []{{.Etype}}) {
l := len(s)
if l >= {{.Count}} {
x.StoreSlice(s)
return
}
if l == 0 {
return
}
mask := vecMask{{.EWidth}}[len(vecMask{{.EWidth}})/2-l:]
x.StoreMasked(pa{{.VType}}(s), LoadInt{{.WxC}}Slice(mask).asMask())
}
`)
625
// avx2SmallLoadSlicePartTemplate generates, for the 8- and 16-bit unsigned
// shapes in avx2SmallLoadPunShapes, Load<VType>SlicePart and
// StoreSlicePart. They reinterpret the unsigned slice as a signed slice of
// the same element width with unsafe.Slice and delegate to the signed
// integer versions.
var avx2SmallLoadSlicePartTemplate = shapedTemplateOf(avx2SmallLoadPunShapes, "avx 2 small load slice part", `
// Load{{.VType}}SlicePart loads a {{.VType}} from the slice s.
// If s has fewer than {{.Count}} elements, the remaining elements of the vector are filled with zeroes.
// If s has {{.Count}} or more elements, the function is equivalent to Load{{.VType}}Slice.
func Load{{.VType}}SlicePart(s []{{.Etype}}) {{.VType}} {
if len(s) == 0 {
var zero {{.VType}}
return zero
}
t := unsafe.Slice((*int{{.EWidth}})(unsafe.Pointer(&s[0])), len(s))
return LoadInt{{.WxC}}SlicePart(t).As{{.VType}}()
}

// StoreSlicePart stores the {{.Count}} elements of x into the slice s.
// It stores as many elements as will fit in s.
// If s has {{.Count}} or more elements, the method is equivalent to x.StoreSlice.
func (x {{.VType}}) StoreSlicePart(s []{{.Etype}}) {
if len(s) == 0 {
return
}
t := unsafe.Slice((*int{{.EWidth}})(unsafe.Pointer(&s[0])), len(s))
x.AsInt{{.WxC}}().StoreSlicePart(t)
}
`)
650
651 func (t templateData) CPUfeature() string {
652 switch t.Vwidth {
653 case 128:
654 return "AVX"
655 case 256:
656 return "AVX2"
657 case 512:
658 return "AVX512"
659 }
660 panic(fmt.Errorf("unexpected vector width %d", t.Vwidth))
661 }
662
// avx2SignedComparisonsTemplate generates, for the shapes in
// avx2SignedComparisons, the methods Less, GreaterEqual, LessEqual and
// NotEqual. Each is built from Greater or Equal, with the negated forms
// produced by XOR-ing the result with an all-ones vector (x.Equal(x)).
var avx2SignedComparisonsTemplate = shapedTemplateOf(avx2SignedComparisons, "avx2 signed comparisons", `
// Less returns a mask whose elements indicate whether x < y.
//
// Emulated, CPU Feature: {{.CPUfeature}}
func (x {{.VType}}) Less(y {{.VType}}) Mask{{.WxC}} {
return y.Greater(x)
}

// GreaterEqual returns a mask whose elements indicate whether x >= y.
//
// Emulated, CPU Feature: {{.CPUfeature}}
func (x {{.VType}}) GreaterEqual(y {{.VType}}) Mask{{.WxC}} {
ones := x.Equal(x).ToInt{{.WxC}}()
return y.Greater(x).ToInt{{.WxC}}().Xor(ones).asMask()
}

// LessEqual returns a mask whose elements indicate whether x <= y.
//
// Emulated, CPU Feature: {{.CPUfeature}}
func (x {{.VType}}) LessEqual(y {{.VType}}) Mask{{.WxC}} {
ones := x.Equal(x).ToInt{{.WxC}}()
return x.Greater(y).ToInt{{.WxC}}().Xor(ones).asMask()
}

// NotEqual returns a mask whose elements indicate whether x != y.
//
// Emulated, CPU Feature: {{.CPUfeature}}
func (x {{.VType}}) NotEqual(y {{.VType}}) Mask{{.WxC}} {
ones := x.Equal(x).ToInt{{.WxC}}()
return x.Equal(y).ToInt{{.WxC}}().Xor(ones).asMask()
}
`)
695
// bitWiseIntTemplate generates, for the signed integer shapes in
// intShapes, the method Not, implemented as XOR with an all-ones vector
// obtained from x.Equal(x).
var bitWiseIntTemplate = shapedTemplateOf(intShapes, "bitwise int complement", `
// Not returns the bitwise complement of x.
//
// Emulated, CPU Feature: {{.CPUfeature}}
func (x {{.VType}}) Not() {{.VType}} {
return x.Xor(x.Equal(x).ToInt{{.WxC}}())
}
`)
704
// bitWiseUintTemplate generates, for the unsigned integer shapes in
// uintShapes, the method Not, implemented as XOR with an all-ones vector
// obtained from x.Equal(x) and converted back to the unsigned vector type.
var bitWiseUintTemplate = shapedTemplateOf(uintShapes, "bitwise uint complement", `
// Not returns the bitwise complement of x.
//
// Emulated, CPU Feature: {{.CPUfeature}}
func (x {{.VType}}) Not() {{.VType}} {
return x.Xor(x.Equal(x).ToInt{{.WxC}}().As{{.VType}}())
}
`)
713
714
715
716
717
718
719 func (t templateData) CPUfeatureAVX2if8() string {
720 if t.EWidth == 8 {
721 return "AVX2"
722 }
723 return t.CPUfeature()
724 }
725
// avx2UnsignedComparisonsTemplate generates, for the shapes in
// avx2UnsignedComparisons, the methods Greater, Less, GreaterEqual,
// LessEqual and NotEqual. The ordered comparisons reinterpret the operands
// as signed vectors, XOR both with a vector holding only the sign bit of
// each element, and then use the signed Greater. For 8-bit elements the
// sign-bit vector comes from a Broadcast call; for wider elements it is an
// all-ones vector shifted left by EWidth-1.
var avx2UnsignedComparisonsTemplate = shapedTemplateOf(avx2UnsignedComparisons, "avx2 unsigned comparisons", `
// Greater returns a mask whose elements indicate whether x > y.
//
// Emulated, CPU Feature: {{.CPUfeatureAVX2if8}}
func (x {{.VType}}) Greater(y {{.VType}}) Mask{{.WxC}} {
a, b := x.AsInt{{.WxC}}(), y.AsInt{{.WxC}}()
{{- if eq .EWidth 8}}
signs := BroadcastInt{{.WxC}}(-1 << ({{.EWidth}}-1))
{{- else}}
ones := x.Equal(x).ToInt{{.WxC}}()
signs := ones.ShiftAllLeft({{.EWidth}}-1)
{{- end }}
return a.Xor(signs).Greater(b.Xor(signs))
}

// Less returns a mask whose elements indicate whether x < y.
//
// Emulated, CPU Feature: {{.CPUfeatureAVX2if8}}
func (x {{.VType}}) Less(y {{.VType}}) Mask{{.WxC}} {
a, b := x.AsInt{{.WxC}}(), y.AsInt{{.WxC}}()
{{- if eq .EWidth 8}}
signs := BroadcastInt{{.WxC}}(-1 << ({{.EWidth}}-1))
{{- else}}
ones := x.Equal(x).ToInt{{.WxC}}()
signs := ones.ShiftAllLeft({{.EWidth}}-1)
{{- end }}
return b.Xor(signs).Greater(a.Xor(signs))
}

// GreaterEqual returns a mask whose elements indicate whether x >= y.
//
// Emulated, CPU Feature: {{.CPUfeatureAVX2if8}}
func (x {{.VType}}) GreaterEqual(y {{.VType}}) Mask{{.WxC}} {
a, b := x.AsInt{{.WxC}}(), y.AsInt{{.WxC}}()
ones := x.Equal(x).ToInt{{.WxC}}()
{{- if eq .EWidth 8}}
signs := BroadcastInt{{.WxC}}(-1 << ({{.EWidth}}-1))
{{- else}}
signs := ones.ShiftAllLeft({{.EWidth}}-1)
{{- end }}
return b.Xor(signs).Greater(a.Xor(signs)).ToInt{{.WxC}}().Xor(ones).asMask()
}

// LessEqual returns a mask whose elements indicate whether x <= y.
//
// Emulated, CPU Feature: {{.CPUfeatureAVX2if8}}
func (x {{.VType}}) LessEqual(y {{.VType}}) Mask{{.WxC}} {
a, b := x.AsInt{{.WxC}}(), y.AsInt{{.WxC}}()
ones := x.Equal(x).ToInt{{.WxC}}()
{{- if eq .EWidth 8}}
signs := BroadcastInt{{.WxC}}(-1 << ({{.EWidth}}-1))
{{- else}}
signs := ones.ShiftAllLeft({{.EWidth}}-1)
{{- end }}
return a.Xor(signs).Greater(b.Xor(signs)).ToInt{{.WxC}}().Xor(ones).asMask()
}

// NotEqual returns a mask whose elements indicate whether x != y.
//
// Emulated, CPU Feature: {{.CPUfeature}}
func (x {{.VType}}) NotEqual(y {{.VType}}) Mask{{.WxC}} {
a, b := x.AsInt{{.WxC}}(), y.AsInt{{.WxC}}()
ones := x.Equal(x).ToInt{{.WxC}}()
return a.Equal(b).ToInt{{.WxC}}().Xor(ones).asMask()
}
`)
792
// unsafePATemplate generates, for every shape in allShapes, the helper
// pa<VType>, which converts the address of a slice's first element to a
// pointer to an array of Count elements without checking the slice length.
// The generated helper indexes s[0], so it must not be called with an
// empty slice.
var unsafePATemplate = templateOf("unsafe PA helper", `
// pa{{.VType}} returns a type-unsafe pointer to array that can
// only be used with partial load/store operations that only
// access the known-safe portions of the array.
//
//go:nocheckptr
func pa{{.VType}}(s []{{.Etype}}) *[{{.Count}}]{{.Etype}} {
return (*[{{.Count}}]{{.Etype}})(unsafe.Pointer(&s[0]))
}
`)
803
// avx2MaskedTemplate generates, for the 128- and 256-bit shapes in
// avx2Shapes, the methods Masked and Merge. Masked ANDs x with the mask
// converted to an integer vector. Merge blends through the 8-bit-element
// integer view (BxC) of the operands and mask; only Int vectors that
// already have 8-bit elements call blend directly.
var avx2MaskedTemplate = shapedTemplateOf(avx2Shapes, "avx2 .Masked methods", `
// Masked returns x but with elements zeroed where mask is false.
//
// Emulated, CPU Feature: {{.CPUfeature}}
func (x {{.VType}}) Masked(mask Mask{{.WxC}}) {{.VType}} {
im := mask.ToInt{{.WxC}}()
{{- if eq .Base "Int" }}
return im.And(x)
{{- else}}
return x.AsInt{{.WxC}}().And(im).As{{.VType}}()
{{- end -}}
}

// Merge returns x but with elements set to y where mask is false.
//
// Emulated, CPU Feature: {{.CPUfeature}}
func (x {{.VType}}) Merge(y {{.VType}}, mask Mask{{.WxC}}) {{.VType}} {
{{- if eq .BxC .WxC -}}
im := mask.ToInt{{.BxC}}()
{{- else}}
im := mask.ToInt{{.WxC}}().AsInt{{.BxC}}()
{{- end -}}
{{- if and (eq .Base "Int") (eq .BxC .WxC) }}
return y.blend(x, im)
{{- else}}
ix := x.AsInt{{.BxC}}()
iy := y.AsInt{{.BxC}}()
return iy.blend(ix, im).As{{.VType}}()
{{- end -}}
}
`)
835
836
// avx512MaskedTemplate generates, for the 512-bit shapes in avx512Shapes,
// the methods Masked and Merge. Masked ANDs x with the mask converted to
// an integer vector. Merge calls blendMasked with the mask itself; non-Int
// vectors go through their integer view of the same shape.
var avx512MaskedTemplate = shapedTemplateOf(avx512Shapes, "avx512 .Masked methods", `
// Masked returns x but with elements zeroed where mask is false.
//
// Emulated, CPU Feature: AVX512
func (x {{.VType}}) Masked(mask Mask{{.WxC}}) {{.VType}} {
im := mask.ToInt{{.WxC}}()
{{- if eq .Base "Int" }}
return im.And(x)
{{- else}}
return x.AsInt{{.WxC}}().And(im).As{{.VType}}()
{{- end -}}
}

// Merge returns x but with elements set to y where mask is false.
//
// Emulated, CPU Feature: AVX512
func (x {{.VType}}) Merge(y {{.VType}}, mask Mask{{.WxC}}) {{.VType}} {
{{- if eq .Base "Int" }}
return y.blendMasked(x, mask)
{{- else}}
ix := x.AsInt{{.WxC}}()
iy := y.AsInt{{.WxC}}()
return iy.blendMasked(ix, mask).As{{.VType}}()
{{- end -}}
}
`)
863
864 func (t templateData) CPUfeatureBC() string {
865 switch t.Vwidth {
866 case 128:
867 return "AVX2"
868 case 256:
869 return "AVX2"
870 case 512:
871 if t.EWidth <= 16 {
872 return "AVX512BW"
873 }
874 return "AVX512F"
875 }
876 panic(fmt.Errorf("unexpected vector width %d", t.Vwidth))
877 }
878
// broadcastTemplate generates, for every shape in allShapes, the function
// Broadcast<VType>. It sets element 0 of a zero 128-bit vector of the same
// element type (see As128BitVec) to x and calls Broadcast1To<Count> on it.
var broadcastTemplate = templateOf("Broadcast functions", `
// Broadcast{{.VType}} returns a vector with the input
// x assigned to all elements of the output.
//
// Emulated, CPU Feature: {{.CPUfeatureBC}}
func Broadcast{{.VType}}(x {{.Etype}}) {{.VType}} {
var z {{.As128BitVec }}
return z.SetElem(0, x).Broadcast1To{{.Count}}()
}
`)
889
// maskCvtTemplate generates, for the signed integer shapes in intShapes,
// the method ToMask, implemented as NotEqual against the zero vector.
var maskCvtTemplate = shapedTemplateOf(intShapes, "Mask conversions", `
// ToMask converts from {{.Base}}{{.WxC}} to Mask{{.WxC}}, mask element is set to true when the corresponding vector element is non-zero.
func (from {{.Base}}{{.WxC}}) ToMask() (to Mask{{.WxC}}) {
return from.NotEqual({{.Base}}{{.WxC}}{})
}
`)
896
// stringTemplate generates, for every shape in allShapes, the method
// String, which stores the vector into an array and formats it with
// sliceToString.
var stringTemplate = shapedTemplateOf(allShapes, "String methods", `
// String returns a string representation of SIMD vector x.
func (x {{.VType}}) String() string {
var s [{{.Count}}]{{.Etype}}
x.Store(&s)
return sliceToString(s[:])
}
`)
905
// Path prefixes for the default output file names registered in main.
const (
	// SIMD prefixes the default names of the generated package files.
	SIMD = "../../"
	// TD prefixes the default names of the generated test-helper files.
	TD = "../../internal/simd_test/"
	// SSA prefixes the name of the generated compiler helper file.
	SSA = "../../../../cmd/compile/internal/ssa/"
)
909
// main parses the output file name flags and generates each file from its
// group of templates. Setting a flag to the empty string skips that file.
// The tern_helpers.go file in the SSA directory is always regenerated.
func main() {
	// Each flag defaults to the file's usual location under SIMD or TD.
	sl := flag.String("sl", SIMD+"slice_gen_amd64.go", "file name for slice operations")
	cm := flag.String("cm", SIMD+"compare_gen_amd64.go", "file name for comparison operations")
	mm := flag.String("mm", SIMD+"maskmerge_gen_amd64.go", "file name for mask/merge operations")
	op := flag.String("op", SIMD+"other_gen_amd64.go", "file name for other operations")
	ush := flag.String("ush", SIMD+"unsafe_helpers.go", "file name for unsafe helpers")
	bh := flag.String("bh", TD+"binary_helpers_test.go", "file name for binary test helpers")
	uh := flag.String("uh", TD+"unary_helpers_test.go", "file name for unary test helpers")
	th := flag.String("th", TD+"ternary_helpers_test.go", "file name for ternary test helpers")
	ch := flag.String("ch", TD+"compare_helpers_test.go", "file name for compare test helpers")
	cmh := flag.String("cmh", TD+"comparemasked_helpers_test.go", "file name for compare-masked test helpers")
	flag.Parse()

	// NOTE(review): one is defined outside this view; presumably it writes
	// the prologue followed by each template's expansions to the named file.

	// Slice load/store helpers; the prologue imports unsafe.
	if *sl != "" {
		one(*sl, unsafePrologue,
			sliceTemplate,
			avx512MaskedLoadSlicePartTemplate,
			avx2MaskedLoadSlicePartTemplate,
			avx2SmallLoadSlicePartTemplate,
		)
	}
	// Emulated comparison methods.
	if *cm != "" {
		one(*cm, prologue,
			avx2SignedComparisonsTemplate,
			avx2UnsignedComparisonsTemplate,
		)
	}
	// Masked and Merge methods.
	if *mm != "" {
		one(*mm, prologue,
			avx2MaskedTemplate,
			avx512MaskedTemplate,
		)
	}
	// Broadcast, ToMask, Not and String.
	if *op != "" {
		one(*op, prologue,
			broadcastTemplate,
			maskCvtTemplate,
			bitWiseIntTemplate,
			bitWiseUintTemplate,
			stringTemplate,
		)
	}
	// pa<VType> pointer-to-array helpers.
	if *ush != "" {
		one(*ush, unsafePrologue, unsafePATemplate)
	}
	// Test helpers for unary methods, including all conversion helpers.
	if *uh != "" {
		one(*uh, curryTestPrologue("unary simd methods"), unaryTemplate,
			unaryToInt8, unaryToUint8, unaryToInt16, unaryToUint16,
			unaryToInt32, unaryToUint32, unaryToInt64, unaryToUint64,
			unaryToFloat32, unaryToFloat64,
			unaryToInt64x2, unaryToInt64x4,
			unaryToUint64x2, unaryToUint64x4,
			unaryToInt32x4, unaryToInt32x8,
			unaryToUint32x4, unaryToUint32x8,
			unaryToInt16x8, unaryToUint16x8,
			unaryFlakyTemplate,
		)
	}
	// Test helpers for binary, ternary, comparison and masked-comparison methods.
	if *bh != "" {
		one(*bh, curryTestPrologue("binary simd methods"), binaryTemplate)
	}
	if *th != "" {
		one(*th, curryTestPrologue("ternary simd methods"), ternaryTemplate, ternaryFlakyTemplate)
	}
	if *ch != "" {
		one(*ch, curryTestPrologue("simd methods that compare two operands"), compareTemplate, compareUnaryTemplate)
	}
	if *cmh != "" {
		one(*cmh, curryTestPrologue("simd methods that compare two operands under a mask"), compareMaskedTemplate)
	}

	// Not controlled by a flag: always written to the SSA directory.
	nonTemplateRewrites(SSA+"tern_helpers.go", ssaPrologue, classifyBooleanSIMD, ternOpForLogical)

}
984
985 func ternOpForLogical(out io.Writer) {
986 fmt.Fprintf(out, `
987 func ternOpForLogical(op Op) Op {
988 switch op {
989 `)
990
991 intShapes.forAllShapes(func(seq int, t, upperT string, w, c int, out io.Writer) {
992 wt, ct := w, c
993 if wt < 32 {
994 wt = 32
995 ct = (w * c) / wt
996 }
997 fmt.Fprintf(out, "case OpAndInt%[1]dx%[2]d, OpOrInt%[1]dx%[2]d, OpXorInt%[1]dx%[2]d,OpAndNotInt%[1]dx%[2]d: return OpternInt%dx%d\n", w, c, wt, ct)
998 fmt.Fprintf(out, "case OpAndUint%[1]dx%[2]d, OpOrUint%[1]dx%[2]d, OpXorUint%[1]dx%[2]d,OpAndNotUint%[1]dx%[2]d: return OpternUint%dx%d\n", w, c, wt, ct)
999 }, out)
1000
1001 fmt.Fprintf(out, `
1002 }
1003 return op
1004 }
1005 `)
1006
1007 }
1008
1009 func classifyBooleanSIMD(out io.Writer) {
1010 fmt.Fprintf(out, `
1011 type SIMDLogicalOP uint8
1012 const (
1013 // boolean simd operations, for reducing expression to VPTERNLOG* instructions
1014 // sloInterior is set for non-root nodes in logical-op expression trees.
1015 // the operations are even-numbered.
1016 sloInterior SIMDLogicalOP = 1
1017 sloNone SIMDLogicalOP = 2 * iota
1018 sloAnd
1019 sloOr
1020 sloAndNot
1021 sloXor
1022 sloNot
1023 )
1024 func classifyBooleanSIMD(v *Value) SIMDLogicalOP {
1025 switch v.Op {
1026 case `)
1027 intShapes.forAllShapes(func(seq int, t, upperT string, w, c int, out io.Writer) {
1028 op := "And"
1029 if seq > 0 {
1030 fmt.Fprintf(out, ",Op%s%s%dx%d", op, upperT, w, c)
1031 } else {
1032 fmt.Fprintf(out, "Op%s%s%dx%d", op, upperT, w, c)
1033 }
1034 seq++
1035 }, out)
1036
1037 fmt.Fprintf(out, `:
1038 return sloAnd
1039
1040 case `)
1041 intShapes.forAllShapes(func(seq int, t, upperT string, w, c int, out io.Writer) {
1042 op := "Or"
1043 if seq > 0 {
1044 fmt.Fprintf(out, ",Op%s%s%dx%d", op, upperT, w, c)
1045 } else {
1046 fmt.Fprintf(out, "Op%s%s%dx%d", op, upperT, w, c)
1047 }
1048 seq++
1049 }, out)
1050
1051 fmt.Fprintf(out, `:
1052 return sloOr
1053
1054 case `)
1055 intShapes.forAllShapes(func(seq int, t, upperT string, w, c int, out io.Writer) {
1056 op := "AndNot"
1057 if seq > 0 {
1058 fmt.Fprintf(out, ",Op%s%s%dx%d", op, upperT, w, c)
1059 } else {
1060 fmt.Fprintf(out, "Op%s%s%dx%d", op, upperT, w, c)
1061 }
1062 seq++
1063 }, out)
1064
1065 fmt.Fprintf(out, `:
1066 return sloAndNot
1067 `)
1068
1069
1070
1071
1072
1073 intShapes.forAllShapes(
1074 func(seq int, t, upperT string, w, c int, out io.Writer) {
1075 fmt.Fprintf(out, "case OpXor%s%dx%d: ", upperT, w, c)
1076 fmt.Fprintf(out, `
1077 if y := v.Args[1]; y.Op == OpEqual%s%dx%d &&
1078 y.Args[0] == y.Args[1] {
1079 return sloNot
1080 }
1081 `, upperT, w, c)
1082 fmt.Fprintf(out, "return sloXor\n")
1083 }, out)
1084
1085 fmt.Fprintf(out, `
1086 }
1087 return sloNone
1088 }
1089 `)
1090 }
1091
1092
1093
// numberLines returns data with every line prefixed by its 1-based line
// number and ": ", each terminated by a single "\n". A trailing "\r" on a
// line is dropped, and a final newline in data does not produce an extra
// empty line. It is used to list generated source when formatting fails.
func numberLines(data []byte) string {
	var buf bytes.Buffer
	s := bufio.NewScanner(bytes.NewReader(data))
	// The default Scanner gives up (silently, since the error was never
	// checked) on the first line longer than bufio.MaxScanTokenSize, which
	// would truncate the listing. Allowing a token as large as the whole
	// input guarantees every line is emitted.
	s.Buffer(nil, len(data)+1)
	for i := 1; s.Scan(); i++ {
		fmt.Fprintf(&buf, "%d: %s\n", i, s.Bytes())
	}
	return buf.String()
}
1103
1104 func nonTemplateRewrites(filename string, prologue func(s string, out io.Writer), rewrites ...func(out io.Writer)) {
1105 if filename == "" {
1106 return
1107 }
1108
1109 ofile := os.Stdout
1110
1111 if filename != "-" {
1112 var err error
1113 ofile, err = os.Create(filename)
1114 if err != nil {
1115 fmt.Fprintf(os.Stderr, "Could not create the output file %s for the generated code, %v", filename, err)
1116 os.Exit(1)
1117 }
1118 }
1119
1120 out := new(bytes.Buffer)
1121
1122 prologue("tmplgen", out)
1123 for _, rewrite := range rewrites {
1124 rewrite(out)
1125 }
1126
1127 b, err := format.Source(out.Bytes())
1128 if err != nil {
1129 fmt.Fprintf(os.Stderr, "There was a problem formatting the generated code for %s, %v\n", filename, err)
1130 fmt.Fprintf(os.Stderr, "%s\n", numberLines(out.Bytes()))
1131 fmt.Fprintf(os.Stderr, "There was a problem formatting the generated code for %s, %v\n", filename, err)
1132 os.Exit(1)
1133 } else {
1134 ofile.Write(b)
1135 ofile.Close()
1136 }
1137
1138 }
1139
1140 func one(filename string, prologue func(s string, out io.Writer), sats ...shapeAndTemplate) {
1141 if filename == "" {
1142 return
1143 }
1144
1145 ofile := os.Stdout
1146
1147 if filename != "-" {
1148 var err error
1149 ofile, err = os.Create(filename)
1150 if err != nil {
1151 fmt.Fprintf(os.Stderr, "Could not create the output file %s for the generated code, %v", filename, err)
1152 os.Exit(1)
1153 }
1154 }
1155
1156 out := new(bytes.Buffer)
1157
1158 prologue("tmplgen", out)
1159 for _, sat := range sats {
1160 sat.forTemplates(out)
1161 }
1162
1163 b, err := format.Source(out.Bytes())
1164 if err != nil {
1165 fmt.Fprintf(os.Stderr, "There was a problem formatting the generated code for %s, %v\n", filename, err)
1166 fmt.Fprintf(os.Stderr, "%s\n", numberLines(out.Bytes()))
1167 fmt.Fprintf(os.Stderr, "There was a problem formatting the generated code for %s, %v\n", filename, err)
1168 os.Exit(1)
1169 } else {
1170 ofile.Write(b)
1171 ofile.Close()
1172 }
1173
1174 }
1175
View as plain text