1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105
|
package tensor
import (
"testing"
"github.com/stretchr/testify/assert"
)
// TestCS_Basics exercises the compressed sparse (*CS) type. It covers:
//   - construction in both CSC and CSR layouts from coordinate lists,
//     including the panic on an undersized shape;
//   - conversion to *Dense with and without transposition, and equality;
//   - element access via At, cloning, and the scalar accessors;
//   - the sparse iterator.
func TestCS_Basics(t *testing.T) {
	assert := assert.New(t)

	// Each layout gets its own copy of the coordinates and values.
	// NOTE(review): presumably the constructors may reorder these slices in
	// place, so sharing them between the CSC and CSR cases is avoided — confirm.
	xs0 := []int{1, 2, 6, 8}
	ys0 := []int{1, 2, 1, 6}
	xs1 := []int{1, 2, 6, 8}
	ys1 := []int{1, 2, 1, 6}
	vals0 := []float64{3, 1, 4, 1}
	vals1 := []float64{3, 1, 4, 1}

	var T0, T1 *CS
	var d0, d1 *Dense
	var dp0, dp1 *Dense

	// The largest coordinates (8 and 6) do not fit in a 7x6 shape, so
	// construction must panic.
	fails := func() {
		CSCFromCoord(Shape{7, 6}, xs0, ys0, vals0)
	}
	assert.Panics(fails)

	// Test CSC
	T0 = CSCFromCoord(Shape{9, 7}, xs0, ys0, vals0)
	d0 = T0.Dense()
	T0.T()
	dp0 = T0.Dense()
	T0.UT() // untranspose as Materialize() will be called below

	// Test CSR: the same undersized shape must be rejected here as well.
	fails = func() {
		CSRFromCoord(Shape{7, 6}, xs1, ys1, vals1)
	}
	assert.Panics(fails)
	T1 = CSRFromCoord(Shape{9, 7}, xs1, ys1, vals1)
	d1 = T1.Dense()
	T1.T()
	dp1 = T1.Dense()
	T1.UT()

	t.Logf("%v %v", T0.indptr, T0.indices)
	t.Logf("%v %v", T1.indptr, T1.indices)

	// Both layouts were built from the same coordinates, so their dense forms
	// (and dense transposes) must agree even though the sparse values differ.
	assert.True(d0.Eq(d1), "%+#v\n %+#v\n", d0, d1)
	assert.True(dp0.Eq(dp1))
	assert.True(T1.Eq(T1))
	assert.False(T0.Eq(T1))

	// At: (1,1) holds 3 in both layouts; (3,3) is not a stored coordinate
	// and so must read back as zero.
	checkAt := func(name string, cs *CS, i, j int, want float64) {
		t.Helper()
		got, err := cs.At(i, j)
		if err != nil {
			t.Errorf("%s.At(%d, %d): %v", name, i, j, err)
			return // got is meaningless on error; type-asserting it could panic
		}
		if v, ok := got.(float64); !ok || v != want {
			t.Errorf("Expected %v. Got %v - %s[%d,%d]", want, got, name, i, j)
		}
	}
	checkAt("T0", T0, 1, 1, 3.0)
	checkAt("T1", T1, 1, 1, 3.0)
	checkAt("T0", T0, 3, 3, 0.0)
	checkAt("T1", T1, 3, 3, 0.0)

	// Test clone
	T2 := T0.Clone()
	assert.True(T0.Eq(T2))

	// Scalar representation: a 9x7 matrix is not a scalar, so asking for its
	// scalar value must panic.
	assert.False(T0.IsScalar())
	fails = func() {
		T0.ScalarValue()
	}
	assert.Panics(fails)
	assert.Equal(len(vals0), T0.NonZeroes())

	// Sparse Iterator: collect every index reported as valid. The loop stops
	// at the first non-nil error.
	// NOTE(review): presumably that error only signals exhaustion — confirm.
	it := T0.Iterator()
	var valids []int
	correctValids := []int{0, 2, 1, 3}
	for i, valid, err := it.NextValidity(); err == nil; i, valid, err = it.NextValidity() {
		if valid {
			valids = append(valids, i)
		}
	}
	assert.Equal(correctValids, valids)
}
|