Better err msgs+better detection of coroutines

This commit is contained in:
bloeys
2022-11-04 02:54:10 +04:00
parent ac5adbf8a3
commit 1d4451dac2

54
main.go
View File

@ -84,10 +84,6 @@ func genHasGenChecksOnOriginalFuncs(cwd string) {
panic(err) panic(err)
} }
// if len(pkgs) != 1 {
// panic(fmt.Sprintf("expected to find one package but found %d", len(pkgs)))
// }
for _, pkg := range pkgs { for _, pkg := range pkgs {
p := &processor{ p := &processor{
@ -128,10 +124,20 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
} }
coroutineParamName := getCoroutineParamNameFromFuncDecl(funcDecl) coroutineParamName := getCoroutineParamNameFromFuncDecl(funcDecl)
if coroutineParamName == "" || !funcDeclCallsCoroutineBegin(funcDecl, coroutineParamName) { if coroutineParamName == "" {
return false return false
} }
hasBegin := funcHasSelInBody(funcDecl, coroutineParamName, "Begin")
hasYield := funcHasSelInBody(funcDecl, coroutineParamName, "Yield")
if !hasBegin && !hasYield {
return false
}
if hasYield && !hasBegin {
panic(fmt.Sprintf("Function '%s' in file '%s' has a 'Yield()' call but no 'Begin()'. Please ensure your coroutines have 'Begin()'", funcDecl.Name.Name, p.fset.File(funcDecl.Pos()).Name()))
}
beginBodyListIndex := -1 beginBodyListIndex := -1
lastCaseEndBodyListIndex := -1 lastCaseEndBodyListIndex := -1
switchStmt := &ast.SwitchStmt{ switchStmt := &ast.SwitchStmt{
@ -255,12 +261,10 @@ func (p *processor) genCogoFuncsNodeProcessor(c *astutil.Cursor) bool {
continue continue
} }
if !funcCallHasLhsName(cogoFuncSelExpr, coroutineParamName) { if !selExprHasLhsName(cogoFuncSelExpr, coroutineParamName) {
continue continue
} }
// cogoFuncCallLineNum := p.fset.File(cogoFuncSelExpr.Pos()).Line(cogoFuncSelExpr.Pos())
// Now that we found a call to cogo decide what to do // Now that we found a call to cogo decide what to do
if cogoFuncSelExpr.Sel.Name == "Begin" { if cogoFuncSelExpr.Sel.Name == "Begin" {
@ -531,7 +535,7 @@ func (p *processor) genHasGenChecksOnOriginalFuncsNodeProcessor(c *astutil.Curso
} }
coroutineParamName := getCoroutineParamNameFromFuncDecl(funcDecl) coroutineParamName := getCoroutineParamNameFromFuncDecl(funcDecl)
if coroutineParamName == "" || !funcDeclCallsCoroutineBegin(funcDecl, coroutineParamName) { if coroutineParamName == "" || !funcHasSelInBody(funcDecl, coroutineParamName, "Begin") {
return false return false
} }
@ -599,7 +603,7 @@ func createStmtFromSelFuncCall(lhs, rhs string) ast.Stmt {
} }
} }
func funcDeclCallsCoroutineBegin(fd *ast.FuncDecl, coroutineParamName string) bool { func funcHasSelInBody(fd *ast.FuncDecl, selLhs, selRhs string) bool {
if fd.Body == nil || len(fd.Body.List) == 0 { if fd.Body == nil || len(fd.Body.List) == 0 {
return false return false
@ -618,12 +622,12 @@ func funcDeclCallsCoroutineBegin(fd *ast.FuncDecl, coroutineParamName string) bo
continue continue
} }
pkgFuncCallExpr, ok := callExpr.Fun.(*ast.SelectorExpr) selExpr, ok := callExpr.Fun.(*ast.SelectorExpr)
if !ok { if !ok {
continue continue
} }
if funcCallHasLhsName(pkgFuncCallExpr, coroutineParamName) { if selExprIs(selExpr, selLhs, selRhs) {
return true return true
} }
} }
@ -633,6 +637,11 @@ func funcDeclCallsCoroutineBegin(fd *ast.FuncDecl, coroutineParamName string) bo
func getCoroutineParamNameFromFuncDecl(fd *ast.FuncDecl) string { func getCoroutineParamNameFromFuncDecl(fd *ast.FuncDecl) string {
// If func doesn't take one parameter then it is not a coroutine
if len(fd.Type.Params.List) == 0 {
return ""
}
for _, p := range fd.Type.Params.List { for _, p := range fd.Type.Params.List {
ptrExpr, ok := p.Type.(*ast.StarExpr) ptrExpr, ok := p.Type.(*ast.StarExpr)
@ -661,11 +670,7 @@ func getCoroutineParamNameFromFuncDecl(fd *ast.FuncDecl) string {
return "" return ""
} }
// func getIdentNameFromExprOrPanic(e ast.Expr) string { func selExprHasLhsName(selExpr *ast.SelectorExpr, pkgName string) bool {
// return e.(*ast.Ident).Name
// }
func funcCallHasLhsName(selExpr *ast.SelectorExpr, pkgName string) bool {
pkgIdent, ok := selExpr.X.(*ast.Ident) pkgIdent, ok := selExpr.X.(*ast.Ident)
return ok && pkgIdent.Name == pkgName return ok && pkgIdent.Name == pkgName
} }
@ -680,21 +685,6 @@ func selExprIs(selExpr *ast.SelectorExpr, lhs, rhs string) bool {
return pkgIdentExpr.Name == lhs && selExpr.Sel.Name == rhs return pkgIdentExpr.Name == lhs && selExpr.Sel.Name == rhs
} }
// func filter[T any](arr []T, where func(x T) bool) []T {
// out := []T{}
// for i := 0; i < len(arr); i++ {
// if !where(arr[i]) {
// continue
// }
// out = append(out, arr[i])
// }
// return out
// }
func getCaseWithStmts(caseConditions []ast.Expr, stmts []ast.Stmt) *ast.CaseClause { func getCaseWithStmts(caseConditions []ast.Expr, stmts []ast.Stmt) *ast.CaseClause {
return &ast.CaseClause{ return &ast.CaseClause{
List: caseConditions, List: caseConditions,