Can I create a function that must only be used with defer?

Issue

For example:

package mypkg // placeholder name; "package" itself is a reserved word

// Dear user, CleanUp must only be used with defer: defer CleanUp()
func CleanUp() {
    // some logic to check if call was deferred
    // do tear down
}

And in userland code:

func main() {
    mypkg.CleanUp() // PANIC, CleanUp must be deferred!
}

But all should be fine if the user runs:

func main() {
    defer mypkg.CleanUp() // good job, no panic
}

Things I already tried:

func DeferCleanUp() {
    defer func() { /* do tear down */ }()
    // But then I realized this was exactly the opposite of what I needed:
    // the user doesn't need to call defer CleanUp anymore, but...
}
// ...now if the API is misused, it can cause problems too:
defer DeferCleanUp() // a defer inception xD, question remains.
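Why the DeferCleanUp attempt does the opposite of what was wanted: a defer inside DeferCleanUp runs when DeferCleanUp itself returns, which is right at the call site, not when the caller finishes. A minimal sketch that records the order in which things happen; deferCleanUp, work, deferredWork and events are hypothetical names for illustration only.

```go
package main

import "fmt"

// events records the order in which things happen (illustration only).
var events []string

// deferCleanUp mirrors the DeferCleanUp attempt: its deferred teardown fires
// as soon as deferCleanUp returns, not when its caller returns.
func deferCleanUp() {
	defer func() { events = append(events, "tear down") }()
}

// work calls deferCleanUp directly, so teardown happens before "work".
func work() {
	deferCleanUp()
	events = append(events, "work")
}

// deferredWork defers deferCleanUp, so teardown happens after "work".
func deferredWork() {
	defer deferCleanUp()
	events = append(events, "work")
}

func main() {
	work()
	fmt.Println(events) // [tear down work]

	events = nil
	deferredWork()
	fmt.Println(events) // [work tear down]
}
```

The "defer inception" case behaves like a plain `defer CleanUp()`: the outer defer postpones the call to deferCleanUp, and its inner defer then runs immediately when that call returns.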

Solution

Alright, per OP's request and just for laughs, I'm posting this hacky approach, which inspects the call stack and applies some heuristics.

DISCLAIMER: Do not use this in real code. I don't think checking whether a call was deferred is even a good idea.

Also note: this approach only works if the source file is readable, at the path recorded at build time, on the machine running the executable.

Link to gist: https://gist.github.com/dvirsky/dfdfd4066c70e8391dc5 (this doesn’t work in the playground because you can’t read the source file there)

package main

import (
    "bytes"
    "fmt"
    "io/ioutil"
    "runtime"
    "strings"
)

func isDeferred() bool {
    // Let's get the caller's name first (this is CleanUp itself)
    var caller string
    if pc, _, _, ok := runtime.Caller(1); ok {
        caller = function(pc)
    } else {
        panic("No caller")
    }

    // Now let's peek 2 levels above this: level 0 is this function,
    // level 1 is CleanUp(), and level 2, the one we want, is whoever called CleanUp()
    if _, file, line, ok := runtime.Caller(2); ok {

        // now we actually need to read the source file.
        // This should be cached, of course, to avoid terrible performance.
        // I copied this from runtime/debug, so it's a legitimate thing to do :)
        // (As of Go 1.16, ioutil.ReadFile simply calls os.ReadFile.)
        data, err := ioutil.ReadFile(file)
        if err != nil {
            panic("Could not read file")
        }

        // now let's read the exact line of the caller
        lines := bytes.Split(data, []byte{'\n'})
        if line < 1 || line > len(lines) {
            panic("Caller line out of range")
        }
        lineText := strings.TrimSpace(string(lines[line-1]))
        fmt.Printf("Line text: '%s'\n", lineText)

        // Now let's apply some ugly rules of thumb. This is the fragile part.
        // It can be improved with regex or actual AST parsing, but dude...
        return lineText == "}" || // on a simple defer, this is what we get
            !strings.Contains(lineText, caller) || // this handles the case of defer func() { CleanUp() }()
            strings.Contains(lineText, "defer ")
    } // not ok: we were not called from at least 3 levels deep

    return false
}

func CleanUp() {
    if !isDeferred() {
        panic("Not Deferred!")
    }
    // do tear down
}

// This should not panic
func fine() {
    defer CleanUp() 
    
    fmt.Println("Fine!")
}


// this should not panic as well
func alsoFine() {
    defer func() { CleanUp() }()
    
    fmt.Println("Also Fine!")
}

// this should panic
func notFine() {
    CleanUp() 
    
    fmt.Println("Not Fine!")
}

// Taken from the std lib's runtime/debug:
// function returns, if possible, the name of the function containing the PC.
func function(pc uintptr) string {
    fn := runtime.FuncForPC(pc)
    if fn == nil {
        return ""
    }
    name := fn.Name()
    if lastslash := strings.LastIndex(name, "/"); lastslash >= 0 {
        name = name[lastslash+1:]
    }
    if period := strings.Index(name, "."); period >= 0 {
        name = name[period+1:]
    }
    name = strings.Replace(name, "·", ".", -1)
    return name
}

func main() {
    fine()
    alsoFine()
    notFine() // this should panic with "Not Deferred!"
}

Answered By – Not_a_Golfer

Answer Checked By – Mildred Charles (GoLangFix Admin)
