Source file src/cmd/compile/internal/bloop/bloop.go

     1  // Copyright 2025 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package bloop
     6  
// This file contains support routines for keeping the
// results of statements alive in loops such as:
    10  //
    11  //	for b.Loop() {
    12  //		var a, b int
    13  //		a = 5
    14  //		b = 6
    15  //		f(a, b)
    16  //	}
    17  //
    18  // The results of a, b and f(a, b) will be kept alive.
    19  //
    20  // Formally, the lhs (if they are [ir.Name]-s) of
    21  // [ir.AssignStmt], [ir.AssignListStmt],
    22  // [ir.AssignOpStmt], and the results of [ir.CallExpr]
    23  // or its args if it doesn't return a value will be kept
    24  // alive.
    25  //
// The keep-alive logic is implemented by passing the address
// of each such Name to runtime.KeepAlive after the statement.
    28  //
    29  // TODO: currently this is implemented with KeepAlive
    30  // because it will prevent DSE and DCE which is probably
    31  // what we want right now. And KeepAlive takes an ssa
    32  // value instead of a symbol, which is easier to manage.
    33  // But since KeepAlive's context was mainly in the runtime
    34  // and GC, should we implement a new intrinsic that lowers
    35  // to OpVarLive? Peeling out the symbols is a bit tricky
    36  // and also VarLive seems to assume that there exists a
    37  // VarDef on the same symbol that dominates it.
    38  
    39  import (
    40  	"cmd/compile/internal/base"
    41  	"cmd/compile/internal/ir"
    42  	"cmd/compile/internal/typecheck"
    43  	"cmd/compile/internal/types"
    44  	"cmd/internal/src"
    45  )
    46  
    47  // getNameFromNode tries to iteratively peel down the node to
    48  // get the name.
    49  func getNameFromNode(n ir.Node) *ir.Name {
    50  	// Tries to iteratively peel down the node to get the names.
    51  	for n != nil {
    52  		switch n.Op() {
    53  		case ir.ONAME:
    54  			// Found the name, stop the loop.
    55  			return n.(*ir.Name)
    56  		case ir.OSLICE, ir.OSLICE3:
    57  			n = n.(*ir.SliceExpr).X
    58  		case ir.ODOT:
    59  			n = n.(*ir.SelectorExpr).X
    60  		case ir.OCONV, ir.OCONVIFACE, ir.OCONVNOP:
    61  			n = n.(*ir.ConvExpr).X
    62  		case ir.OADDR:
    63  			n = n.(*ir.AddrExpr).X
    64  		case ir.ODOTPTR:
    65  			n = n.(*ir.SelectorExpr).X
    66  		case ir.OINDEX, ir.OINDEXMAP:
    67  			n = n.(*ir.IndexExpr).X
    68  		default:
    69  			n = nil
    70  		}
    71  	}
    72  	return nil
    73  }
    74  
    75  // getAddressableNameFromNode is like getNameFromNode but returns nil if the node is not addressable.
    76  func getAddressableNameFromNode(n ir.Node) *ir.Name {
    77  	if name := getNameFromNode(n); name != nil && ir.IsAddressable(name) {
    78  		return name
    79  	}
    80  	return nil
    81  }
    82  
    83  // getKeepAliveNodes analyzes an IR node and returns a list of nodes that must be kept alive.
    84  func getKeepAliveNodes(pos src.XPos, n ir.Node) ir.Nodes {
    85  	name := getAddressableNameFromNode(n)
    86  	if name != nil {
    87  		debugName(name, pos)
    88  		return ir.Nodes{name}
    89  	} else if deref, ok := n.(*ir.StarExpr); ok && deref != nil {
    90  		if base.Flag.LowerM > 1 {
    91  			base.WarnfAt(pos, "dereference will be kept alive")
    92  		}
    93  		return ir.Nodes{deref}
    94  	} else if base.Flag.LowerM > 1 {
    95  		base.WarnfAt(pos, "expr is unknown to bloop pass")
    96  	}
    97  	return nil
    98  }
    99  
   100  // keepAliveAt returns a statement that is either curNode, or a
   101  // block containing curNode followed by a call to runtime.KeepAlive for each
   102  // node in ns. These calls ensure that nodes in ns will be live until
   103  // after curNode's execution.
   104  func keepAliveAt(ns ir.Nodes, curNode ir.Node) ir.Node {
   105  	if len(ns) == 0 {
   106  		return curNode
   107  	}
   108  
   109  	pos := curNode.Pos()
   110  	calls := ir.Nodes{curNode}
   111  	for _, n := range ns {
   112  		if n == nil || n.Sym() == nil || n.Sym().IsBlank() {
   113  			continue
   114  		}
   115  		if !ir.IsAddressable(n) {
   116  			base.FatalfAt(n.Pos(), "keepAliveAt: node %v is not addressable", n)
   117  		}
   118  		arg := ir.NewConvExpr(pos, ir.OCONV, types.Types[types.TUNSAFEPTR], typecheck.NodAddr(n))
   119  		callExpr := typecheck.Call(pos, typecheck.LookupRuntime("KeepAlive"), ir.Nodes{arg}, false).(*ir.CallExpr)
   120  		callExpr.IsCompilerVarLive = true
   121  		callExpr.NoInline = true
   122  		calls = append(calls, callExpr)
   123  	}
   124  
   125  	return ir.NewBlockStmt(pos, calls)
   126  }
   127  
   128  func debugName(name *ir.Name, pos src.XPos) {
   129  	if base.Flag.LowerM > 1 {
   130  		if name.Linksym() != nil {
   131  			base.WarnfAt(pos, "%s will be kept alive", name.Linksym().Name)
   132  		} else {
   133  			base.WarnfAt(pos, "expr will be kept alive")
   134  		}
   135  	}
   136  }
   137  
   138  // preserveCallResults assigns the results of a call statement to temporary variables to ensure they remain alive.
   139  func preserveCallResults(curFn *ir.Func, call *ir.CallExpr) ir.Node {
   140  	var ns ir.Nodes
   141  	lhs := make(ir.Nodes, call.Fun.Type().NumResults())
   142  	for i, res := range call.Fun.Type().Results() {
   143  		tmp := typecheck.TempAt(call.Pos(), curFn, res.Type)
   144  		lhs[i] = tmp
   145  		ns = append(ns, tmp)
   146  	}
   147  
   148  	if base.Flag.LowerM > 1 {
   149  		plural := ""
   150  		if call.Fun.Type().NumResults() > 1 {
   151  			plural = "s"
   152  		}
   153  		base.WarnfAt(call.Pos(), "function result%s will be kept alive", plural)
   154  	}
   155  
   156  	assign := typecheck.AssignExpr(ir.NewAssignListStmt(call.Pos(), ir.OAS2, lhs, ir.Nodes{call})).(*ir.AssignListStmt)
   157  	assign.Def = true
   158  	for _, tmp := range lhs {
   159  		// Place temp declarations in the loop body to help escape analysis.
   160  		assign.PtrInit().Append(typecheck.Stmt(ir.NewDecl(assign.Pos(), ir.ODCL, tmp.(*ir.Name))))
   161  	}
   162  	return keepAliveAt(ns, assign)
   163  }
   164  
   165  // preserveCallArgs ensures the arguments of a call statement are kept alive by transforming them into temporaries if necessary.
   166  func preserveCallArgs(curFn *ir.Func, call *ir.CallExpr) ir.Node {
   167  	var argTmps ir.Nodes
   168  	var names ir.Nodes
   169  	preserveTmp := func(pos src.XPos, n ir.Node) ir.Node {
   170  		tmp := typecheck.TempAt(pos, curFn, n.Type())
   171  		assign := ir.NewAssignStmt(pos, tmp, n)
   172  		assign.Def = true
   173  		// Place temp declarations in the loop body to help escape analysis.
   174  		assign.PtrInit().Append(typecheck.Stmt(ir.NewDecl(assign.Pos(), ir.ODCL, tmp)))
   175  		argTmps = append(argTmps, typecheck.AssignExpr(assign))
   176  		names = append(names, tmp)
   177  		if base.Flag.LowerM > 1 {
   178  			base.WarnfAt(call.Pos(), "function arg will be kept alive")
   179  		}
   180  		return tmp
   181  	}
   182  	for i, a := range call.Args {
   183  		if name := getAddressableNameFromNode(a); name != nil {
   184  			// If they are name, keep them alive directly.
   185  			debugName(name, call.Pos())
   186  			names = append(names, name)
   187  		} else if a.Op() == ir.OSLICELIT {
   188  			// variadic args are encoded as slice literal.
   189  			s := a.(*ir.CompLitExpr)
   190  			var ns ir.Nodes
   191  			for i, elem := range s.List {
   192  				if name := getAddressableNameFromNode(elem); name != nil {
   193  					debugName(name, call.Pos())
   194  					ns = append(ns, name)
   195  				} else {
   196  					// We need a temporary to save this arg.
   197  					s.List[i] = preserveTmp(elem.Pos(), elem)
   198  				}
   199  			}
   200  			names = append(names, ns...)
   201  		} else {
   202  			// expressions, we need to assign them to temps and change the original arg to reference them.
   203  			call.Args[i] = preserveTmp(call.Pos(), a)
   204  		}
   205  	}
   206  	if len(argTmps) > 0 {
   207  		argTmps = append(argTmps, call)
   208  		return keepAliveAt(names, ir.NewBlockStmt(call.Pos(), argTmps))
   209  	}
   210  	return keepAliveAt(names, call)
   211  }
   212  
   213  // preserveStmt transforms stmt so that any names defined/assigned within it
   214  // are used after stmt's execution, preventing their dead code elimination
   215  // and dead store elimination. The return value is the transformed statement.
   216  func preserveStmt(curFn *ir.Func, stmt ir.Node) ir.Node {
   217  	switch n := stmt.(type) {
   218  	case *ir.AssignStmt:
   219  		// If the left hand side is blank, we need to assign it to a temp
   220  		// so that it can be kept alive.
   221  		if ir.IsBlank(n.X) {
   222  			tmp := typecheck.TempAt(n.Pos(), curFn, n.Y.Type())
   223  			n.X = tmp
   224  			n.Def = true
   225  			n.PtrInit().Append(typecheck.Stmt(ir.NewDecl(n.Pos(), ir.ODCL, tmp)))
   226  			stmt = typecheck.AssignExpr(n)
   227  			n = stmt.(*ir.AssignStmt)
   228  		}
   229  		return keepAliveAt(getKeepAliveNodes(n.Pos(), n.X), n)
   230  	case *ir.AssignListStmt:
   231  		var ns ir.Nodes
   232  		hasBlank := false
   233  		for i, lhs := range n.Lhs {
   234  			if ir.IsBlank(lhs) {
   235  				// If the left hand side has blanks, we need to assign them to temps
   236  				// so that they can be kept alive.
   237  				var typ *types.Type
   238  				// AssignListStmt can have tuple or a list of expressions on the right hand side.
   239  				if len(n.Rhs) == 1 && n.Rhs[0].Type() != nil &&
   240  					n.Rhs[0].Type().IsTuple() &&
   241  					len(n.Lhs) == n.Rhs[0].Type().NumFields() {
   242  					typ = n.Rhs[0].Type().Field(i).Type
   243  				} else if len(n.Rhs) == len(n.Lhs) {
   244  					typ = n.Rhs[i].Type()
   245  				} else {
   246  					// Unrecognized shapes, skip?
   247  					base.WarnfAt(n.Pos(), "unrecognized shape for assign list stmt for blank assignment")
   248  					continue
   249  				}
   250  				tmp := typecheck.TempAt(n.Pos(), curFn, typ)
   251  				n.Lhs[i] = tmp
   252  				n.PtrInit().Append(typecheck.Stmt(ir.NewDecl(n.Pos(), ir.ODCL, tmp)))
   253  				hasBlank = true
   254  			}
   255  			ns = append(ns, getKeepAliveNodes(n.Pos(), n.Lhs[i])...)
   256  		}
   257  		if hasBlank {
   258  			// blank nodes are rewritten to temps, we need to typecheck the node again.
   259  			n.Def = true
   260  			stmt = typecheck.AssignExpr(n)
   261  			n = stmt.(*ir.AssignListStmt)
   262  		}
   263  		return keepAliveAt(ns, n)
   264  	case *ir.AssignOpStmt:
   265  		return keepAliveAt(getKeepAliveNodes(n.Pos(), n.X), n)
   266  	case *ir.CallExpr:
   267  		// The function's results are not assigned, preserve them.
   268  		if n.Fun != nil && n.Fun.Type() != nil && n.Fun.Type().NumResults() != 0 {
   269  			return preserveCallResults(curFn, n)
   270  		}
   271  		// This function doesn't return anything, keep its args alive.
   272  		return preserveCallArgs(curFn, n)
   273  	}
   274  	return stmt
   275  }
   276  
   277  func preserveStmts(curFn *ir.Func, list ir.Nodes) {
   278  	for i := range list {
   279  		list[i] = preserveStmt(curFn, list[i])
   280  	}
   281  }
   282  
   283  // isTestingBLoop returns true if it matches the node as a
   284  // testing.(*B).Loop. See issue #61515.
   285  func isTestingBLoop(t ir.Node) bool {
   286  	if t.Op() != ir.OFOR {
   287  		return false
   288  	}
   289  	nFor, ok := t.(*ir.ForStmt)
   290  	if !ok || nFor.Cond == nil || nFor.Cond.Op() != ir.OCALLFUNC {
   291  		return false
   292  	}
   293  	n, ok := nFor.Cond.(*ir.CallExpr)
   294  	if !ok || n.Fun == nil || n.Fun.Op() != ir.OMETHEXPR {
   295  		return false
   296  	}
   297  	name := ir.MethodExprName(n.Fun)
   298  	if name == nil {
   299  		return false
   300  	}
   301  	if fSym := name.Sym(); fSym != nil && name.Class == ir.PFUNC && fSym.Pkg != nil &&
   302  		fSym.Name == "(*B).Loop" && fSym.Pkg.Path == "testing" {
   303  		// Attempting to match a function call to testing.(*B).Loop
   304  		return true
   305  	}
   306  	return false
   307  }
   308  
   309  type editor struct {
   310  	inBloop bool
   311  	curFn   *ir.Func
   312  }
   313  
   314  func (e editor) edit(n ir.Node) ir.Node {
   315  	e.inBloop = isTestingBLoop(n) || e.inBloop
   316  	// It's in bloop, mark the stmts with bodies.
   317  	ir.EditChildren(n, e.edit)
   318  	if e.inBloop {
   319  		switch n := n.(type) {
   320  		case *ir.ForStmt:
   321  			preserveStmts(e.curFn, n.Body)
   322  		case *ir.IfStmt:
   323  			preserveStmts(e.curFn, n.Body)
   324  			preserveStmts(e.curFn, n.Else)
   325  		case *ir.BlockStmt:
   326  			preserveStmts(e.curFn, n.List)
   327  		case *ir.CaseClause:
   328  			preserveStmts(e.curFn, n.List)
   329  			preserveStmts(e.curFn, n.Body)
   330  		case *ir.CommClause:
   331  			preserveStmts(e.curFn, n.Body)
   332  		case *ir.RangeStmt:
   333  			preserveStmts(e.curFn, n.Body)
   334  		}
   335  	}
   336  	return n
   337  }
   338  
   339  // Walk performs a walk on all functions in the package
   340  // if it imports testing and wrap the results of all qualified
   341  // statements in a runtime.KeepAlive intrinsic call. See package
   342  // doc for more details.
   343  //
   344  //	for b.Loop() {...}
   345  //
   346  // loop's body.
   347  func Walk(pkg *ir.Package) {
   348  	hasTesting := false
   349  	for _, i := range pkg.Imports {
   350  		if i.Path == "testing" {
   351  			hasTesting = true
   352  			break
   353  		}
   354  	}
   355  	if !hasTesting {
   356  		return
   357  	}
   358  	for _, fn := range pkg.Funcs {
   359  		e := editor{false, fn}
   360  		ir.EditChildren(fn, e.edit)
   361  		if ir.MatchAstDump(fn, "bloop") {
   362  			ir.AstDump(fn, "bloop, "+ir.FuncName(fn))
   363  		}
   364  	}
   365  
   366  }
   367  

View as plain text