How to do Unit Testing with gorm

Issue

I’m new in Go and unit test. In my project am using Go with gorm and connecting mysql database.

my queries is how to unit test my code:

My code is below(main.go):

package main

import (
    "encoding/json"
    "fmt"
    "net/http"
    "strconv"
    "time"

    "github.com/gorilla/mux"
    "github.com/jinzhu/gorm"
    _ "github.com/jinzhu/gorm/dialects/mysql"
)

type Jobs struct {
    JobID                  uint   `json: "jobId" gorm:"primary_key;auto_increment"`
    SourcePath             string `json: "sourcePath"`
    Priority               int64  `json: "priority"`
    InternalPriority       string `json: "internalPriority"`
    ExecutionEnvironmentID string `json: "executionEnvironmentID"`
}

type ExecutionEnvironment struct {
    ID                     uint      `json: "id" gorm:"primary_key;auto_increment"`
    ExecutionEnvironmentId string    `json: "executionEnvironmentID"`
    CloudProviderType      string    `json: "cloudProviderType"`
    InfrastructureType     string    `json: "infrastructureType"`
    CloudRegion            string    `json: "cloudRegion"`
    CreatedAt              time.Time `json: "createdAt"`
}

var db *gorm.DB

func initDB() {
    var err error
    dataSourceName := "root:@tcp(localhost:3306)/?parseTime=True"
    db, err = gorm.Open("mysql", dataSourceName)

    if err != nil {
        fmt.Println(err)
        panic("failed to connect database")
    }
    //db.Exec("CREATE DATABASE test")
    db.LogMode(true)
    db.Exec("USE test")
    db.AutoMigrate(&Jobs{}, &ExecutionEnvironment{})
}

func GetAllJobs(w http.ResponseWriter, r *http.Request) {
    w.Header().Set("Content-Type", "application/json")
    fmt.Println("Executing Get All Jobs function")

    var jobs []Jobs
    if err := db.Select("jobs.*, execution_environments.*").Joins("JOIN execution_environments on execution_environments.execution_environment_id = jobs.execution_environment_id").Find(&jobs).Error; err != nil {
        fmt.Println(err)
    }
    fmt.Println()
    if len(jobs) == 0 {
        json.NewEncoder(w).Encode("No data found")
    } else {
        json.NewEncoder(w).Encode(jobs)
    }
}

// create job
func createJob(w http.ResponseWriter, r *http.Request) {
    w.Header().Set("Content-Type", "application/json")
    fmt.Println("Executing Create Jobs function")
    var jobs Jobs
    json.NewDecoder(r.Body).Decode(&jobs)
    db.Create(&jobs)
    json.NewEncoder(w).Encode(jobs)
}

// get job by id
func GetJobById(w http.ResponseWriter, r *http.Request) {
    w.Header().Set("Content-Type", "application/json")
    params := mux.Vars(r)
    jobId := params["jobId"]

    //var job []Jobs
    //db.Preload("Items").First(&job, jobId)
    var jobs []Jobs
    var executionEnvironments []ExecutionEnvironment
    if err := db.Table("jobs").Select("jobs.*, execution_environments.*").Joins("JOIN execution_environments on execution_environments.execution_environment_id = jobs.execution_environment_id").Where("job_id =?", jobId).Find(&jobs).Scan(&executionEnvironments).Error; err != nil {
        fmt.Println(err)
    }

    if len(jobs) == 0 {
        json.NewEncoder(w).Encode("No data found")
    } else {
        json.NewEncoder(w).Encode(jobs)
    }
}

// Delete Job By Id
func DeleteJobById(w http.ResponseWriter, r *http.Request) {
    params := mux.Vars(r)
    jobId := params["jobId"]

    // check data
    var job []Jobs
    db.Table("jobs").Select("jobs.*").Where("job_id=?", jobId).Find(&job)
    if len(job) == 0 {
        json.NewEncoder(w).Encode("Invalid JobId")
    } else {

        id64, _ := strconv.ParseUint(jobId, 10, 64)
        idToDelete := uint(id64)

        db.Where("job_id = ?", idToDelete).Delete(&Jobs{})
        //db.Where("jobId = ?", idToDelete).Delete(&ExecutionEnvironment{})

        json.NewEncoder(w).Encode("Job deleted successfully")
        w.WriteHeader(http.StatusNoContent)
    }

}

// create Execution Environments
func createEnvironments(w http.ResponseWriter, r *http.Request) {
    w.Header().Set("Content-Type", "application/json")
    fmt.Println("Executing Create Execution Environments function")
    var executionEnvironments ExecutionEnvironment
    json.NewDecoder(r.Body).Decode(&executionEnvironments)
    db.Create(&executionEnvironments)
    json.NewEncoder(w).Encode(executionEnvironments)
}

// Get Job Cloud Region
func GetJobCloudRegion(w http.ResponseWriter, r *http.Request) {
    w.Header().Set("Content-Type", "application/json")
    fmt.Println("Executing Get Job Cloud Region function")

    params := mux.Vars(r)
    jobId := params["jobId"]

    //var jobs []Jobs
    var executionEnvironment []ExecutionEnvironment

    db.Table("jobs").Select("execution_environments.*").Joins("JOIN execution_environments on execution_environments.execution_environment_id = jobs.execution_environment_id").Where("jobs.job_id =?", jobId).Find(&executionEnvironment)

    var pUuid []string
    for _, uuid := range executionEnvironment {
        pUuid = append(pUuid, uuid.CloudRegion)
    }
    json.NewEncoder(w).Encode(pUuid)

}

func main() {
    // router
    router := mux.NewRouter()
    // Access URL
    router.HandleFunc("/GetAllJobs", GetAllJobs).Methods("GET")
    router.HandleFunc("/createJob", createJob).Methods("POST")
    router.HandleFunc("/GetJobById/{jobId}", GetJobById).Methods("GET")
    router.HandleFunc("/DeleteJobById/{jobId}", DeleteJobById).Methods("DELETE")

    router.HandleFunc("/createEnvironments", createEnvironments).Methods("POST")
    router.HandleFunc("/GetJobCloudRegion/{jobId}", GetJobCloudRegion).Methods("GET")

    // Initialize db connection
    initDB()

    // config port
    fmt.Printf("Starting server at 8000 \n")
    http.ListenAndServe(":8000", router)
}

I try to create unit test file below, but it is not running it shows like this
enter image description here

main_test.go:

package main

import (
    "log"
    "os"
    "testing"

    "github.com/jinzhu/gorm"
    _ "github.com/jinzhu/gorm/dialects/mysql"
)

func TestinitDB(m *testing.M) {
    dataSourceName := "root:@tcp(localhost:3306)/?parseTime=True"
    db, err := gorm.Open("mysql", dataSourceName)

    if err != nil {
        log.Fatal("failed to connect database")
    }
    //db.Exec("CREATE DATABASE test")
    db.LogMode(true)
    db.Exec("USE test111")
    os.Exit(m.Run())
}

Please help me to write unit test file

Solution

"How to unit test" is a very broad question since it depends on what you want to test. In your example you’re working with remote connections to a database which is usually something that is mocked in unit testing. It’s not clear if that’s what you’re looking for and it’s not a requirement to do so either. By seeing you use different databases I would expect the intention is not to mock.

Start by looking at this post that has already answered your question around how TestMain and testing.M is intended to work.

What your code currently does (if your test name would be named TestMain properly) is add a method around your other tests to do setup and teardown, however you don’t have any other tests to make use of this setup and teardown thus you’ll get the result no tests to run.

It’s not a part of your question but I would suggest try to avoid testing.M until you feel confident in testing Go code. Using testing.T and testing separate units might be easier to understand. You could achieve pretty much the same thing by just calling initDB() in your test and making the initializer take an argument.

func initDB(dbToUse string) {
    // ...
    db.Exec("USE "+dbToUse)
}

You would then call initDB("test") from your main file and initDB("test111") from your test.
You can read about the testing package for Go at pkg.go.dev/testing where you’ll also find the differences between testing.T and testing.M.

Here’s a shorter example with some basic testing that does not require any setup or teardown and that uses testing.T instead of testing.M.

main.go

package main

import "fmt"

func main() {
    fmt.Println(add(1, 2))
}

func add(a, b int) int {
    return a + b
}

main_test.go

package main

import "testing"

func TestAdd(t *testing.T) {
    t.Run("add 2 + 2", func(t *testing.T) {
        want := 4

        // Call the function you want to test.
        got := add(2, 2)

        // Assert that you got your expected response
        if got != want {
            t.Fail()
        }
    })
}

This test will test your method add and ensure it returns the right value when you pass 2, 2 as argument. The use of t.Run is optional but it creates a sub test for you which makes reading the output a bit easier.

Since you test on package level you’ll need to specify what package to test if you’re not using the triple dot format including every package recursively.

To run the test in the example above, specify your package and -v for verbose output.

$ go test ./ -v
=== RUN   TestAdd
=== RUN   TestAdd/add_2_+_2
--- PASS: TestAdd (0.00s)
    --- PASS: TestAdd/add_2_+_2 (0.00s)
PASS
ok      x       (cached)

There is a lot more to learn around this topic as well such as testing frameworks and testing patterns. As an example the testing framework testify helps you do assertions and prints nice output when tests fail and table driven tests is a pretty common pattern in Go.

You’re also writing a HTTP server which usually requires additional testing setup to test properly. Luckily the http package in standard library comes with a sub package named httptest which can help you record external requests or start local servers for external requests. You can also test your handlers by directly calling your handlers with a manually constructed request.

It would look something like this.

func TestSomeHandler(t *testing.T) {
    // Create a request to pass to our handler. We don't have any query parameters for now, so we'll
    // pass 'nil' as the third parameter.
    req, err := http.NewRequest("GET", "/some-endpoint", nil)
    if err != nil {
        t.Fatal(err)
    }

    // We create a ResponseRecorder (which satisfies http.ResponseWriter) to record the response.
    rr := httptest.NewRecorder()
    handler := http.HandlerFunc(SomeHandler)

    // Our handlers satisfy http.Handler, so we can call their ServeHTTP method 
    // directly and pass in our Request and ResponseRecorder.
    handler.ServeHTTP(rr, req)

    // Check the status code is what we expect.
    if status := rr.Code; status != http.StatusOK {
        t.Errorf("handler returned wrong status code: got %v want %v",
            status, http.StatusOK)
    }

Now, to test some of your code. We can run the init method and call any of your services with a response recorder.

package main

import (
    "encoding/json"
    "net/http"
    "net/http/httptest"
    "testing"
)

func TestGetAllJobs(t *testing.T) {
    // Initialize the DB
    initDB("test111")

    req, err := http.NewRequest("GET", "/GetAllJobs", nil)
    if err != nil {
        t.Fatal(err)
    }

    rr := httptest.NewRecorder()
    handler := http.HandlerFunc(GetAllJobs)

    handler.ServeHTTP(rr, req)

    // Check the status code is what we expect.
    if status := rr.Code; status != http.StatusOK {
        t.Errorf("handler returned wrong status code: got %v want %v",
            status, http.StatusOK)
    }

    var response []Jobs
    if err := json.Unmarshal(rr.Body.Bytes(), &response); err != nil {
        t.Errorf("got invalid response, expected list of jobs, got: %v", rr.Body.String())
    }

    if len(response) < 1 {
        t.Errorf("expected at least 1 job, got %v", len(response))
    }

    for _, job := range response {
        if job.SourcePath == "" {
            t.Errorf("expected job id %d to  have a source path, was empty", job.JobID)
        }
    }
}

Answered By – Simon S.

Answer Checked By – Marilyn (GoLangFix Volunteer)

Leave a Reply

Your email address will not be published.