Golang:将“切片”作为参考问题传递

I'm trying to write a program that counts inversions within an array, but my array is not being sorted properly due to reference issues and thus messes up my count even though I thought slices were passed by reference in Golang.

Here is my code:

package main

import (
    "fmt"
)

func InversionCount(a []int) int {
    if len(a) <= 1 {
        return 0
    }
    mid := len(a) / 2
    left := a[:mid]
    right := a[mid:]
    leftCount := InversionCount(left) //not being sorted properly due to reference issues 
    rightCount := InversionCount(right) //not being sorted properly due to reference issues

    res := make([]int, 0, len(right)+len(left)) //temp slice to hold the sorted left side and right side

    iCount := mergeCount(left, right, &res)

    a = res        //assigns the original slice with the temp slice values
    fmt.Println(a) //a in the end is not sorted properly for most cases 
    return iCount + leftCount + rightCount
}

    func mergeCount(left, right []int, res *[]int) int {
        count := 0

        for len(left) > 0 || len(right) > 0 {
            if len(left) == 0 {
                *res = append(*res, right...)
                break
            }
            if len(right) == 0 {
                *res = append(*res, left...)
                break
            }
        if left[0] <= right[0] {
            *res = append(*res, left[0])
            left = left[1:]
        } else { //Inversion has been found
            count += len(left)
            *res = append(*res, right[0])
            right = right[1:]
        }
    }

    return count
}

func main() {
    test := []int{4,2,3,1,5}
    fmt.Print(InversionCount(test))
}

What would be the best possible way to solve this problem? I have tried to do something similar to what I did to the res array by forcing the mergeCountfunction to take in a reference of the array, but it seems very messy and it will give me errors.

You either have to pass a pointer to your slice like:

func InversionCount(a *[]int) int {
    if len(*a) <= 1 {
        return 0
    }
    mid := len(*a) / 2
    left := (*a)[:mid]
    right := (*a)[mid:]
    leftCount := InversionCount(&left)   //not being sorted properly due to reference issues
    rightCount := InversionCount(&right) //not being sorted properly due to reference issues

    res := make([]int, 0, len(right)+len(left)) //temp slice to hold the sorted left side and right side

    iCount := mergeCount(left, right, &res)

    *a = res
    fmt.Println(a) //a in the end is not sorted properly for most cases
    return iCount + leftCount + rightCount
}

playground

Or use copy and change a = res to copy(a, res).

playground

Rather than mutate the slices, I'd just have the functions return the slices obtained during the merge step.

Here's code in that form, including some unit-test-like code which compares the efficient version with a naive O(N^2) count.

package main

import "fmt"

// Inversions returns the input sorted, and the number of inversions found.
func Inversions(a []int) ([]int, int) {
    if len(a) <= 1 {
        return a, 0
    }
    left, lc := Inversions(a[:len(a)/2])
    right, rc := Inversions(a[len(a)/2:])
    merge, mc := mergeCount(left, right)
    return merge, lc + rc + mc
}

func mergeCount(left, right []int) ([]int, int) {
    res := make([]int, 0, len(left)+len(right))
    n := 0
    for len(left) > 0 && len(right) > 0 {
        if left[0] >= right[0] {
            res = append(res, left[0])
            left = left[1:]
        } else {
            res = append(res, right[0])
            right = right[1:]
            n += len(left)
        }
    }
    return append(append(res, left...), right...), n
}

func dumbInversions(a []int) int {
    n := 0
    for i := range a {
        for j := i + 1; j < len(a); j++ {
            if a[i] < a[j] {
                n++
            }
        }
    }
    return n
}

func main() {
    cases := [][]int{
        {},
        {1},
        {1, 2, 3, 4, 5},
        {2, 1, 3, 4, 5},
        {5, 4, 3, 2, 1},
        {2, 2, 1, 1, 3, 3, 4, 4, 1, 1},
    }
    for _, c := range cases {
        want := dumbInversions(c)
        _, got := Inversions(c)
        if want != got {
            fmt.Printf("Inversions(%v)=%d, want %d
", c, got, want)
        }
    }
}