#! /usr/bin/env python

import openturns as ot
import math as m

ot.TESTPREAMBLE()


# TEST 0 -- default constructor and string converter.
print("test 0 : default constructor and string converter")
matrix0 = ot.ComplexMatrix()  # no dimensions given
print("matrix0 = ", repr(matrix0))

# TEST 1 -- sized constructor, then fill each entry through the
# (row, column) item setter and print the result.
print("test number one : constructor with size, operator() and string converter")
matrix1 = ot.ComplexMatrix(2, 2)
for (row, col), value in (
    ((0, 0), 1.0 + 1j),
    ((1, 0), 2.0 + 4j),
    ((0, 1), 3.0 - 1j),
    ((1, 1), 4.0),
):
    matrix1[row, col] = value
print("matrix1 = ", repr(matrix1))

# TEST 2 -- copy constructor: matrix2 is built from matrix1.
print("test 2 : copy constructor and string converter")
matrix2 = ot.ComplexMatrix(matrix1)
print("matrix2 = ", repr(matrix2))

# TEST 3 -- row / column count accessors of matrix1.
print("test 3 : dimension methods")
for label, getter in (
    ("matrix1's nbRows = ", matrix1.getNbRows),
    ("matrix1's nbColumns = ", matrix1.getNbColumns),
):
    print(label, getter())

# TEST 4 -- constructor taking (rows, columns, collection of values).
print("test 4 : constructor with collection method")

# Build the value collection one complex number at a time.
elementsValues = ot.ComplexCollection()
for value in (1.0 - 1j, 2.0 - 1j, 3.0 - 1j, 4.0 + 1j, 5.0 + 1j, 6.0 + 1j):
    elementsValues.add(value)
print("elementsValues = ", repr(elementsValues))

# NOTE(review): six values are supplied for a 2 x 2 matrix; presumably only
# the first four are copied and the rest ignored -- confirm against the
# ComplexMatrix(rowDim, colDim, values) documentation / expected output.
matrix0bis = ot.ComplexMatrix(2, 2, elementsValues)
print("matrix0bis = ", repr(matrix0bis))

# TEST 5 -- transpose(), conjugate() and their combination on matrix1.
print("test 5 : transposition / conjugate method")

# Both results are computed first, then printed in the same order.
matrix4 = matrix1.transpose()
matrix5 = matrix1.conjugate()
for label, result in (
    ("matrix1 transposed = ", matrix4),
    ("matrix1 conjugated = ", matrix5),
):
    print(label, repr(result))

# Combined conjugate + transpose in a single call.
print("transposition and conjugate method")
matrix6 = matrix1.conjugateTranspose()
print("matrix1 conjugated and transposed = ", repr(matrix6))

# TEST 6 -- addition. The sum is formed in both operand orders and the two
# results are compared with ==, which also exercises the equality operator.
print("test 6 : addition method")
sum1, sum2 = matrix1 + matrix4, matrix4 + matrix1
print("sum1 = ", repr(sum1))
print("sum2 = ", repr(sum2))
print("sum1 equals sum2 = ", sum1 == sum2)

# TEST 7 -- subtraction of the transposed matrix from matrix1.
print("test 7 : subtraction method")
diff = matrix1 - matrix4
print("diff = ", repr(diff))

# TEST 8 -- matrix-by-matrix product.
print("test 8 : matrix multiplication method")
prod = matrix1 * matrix4
print("prod = ", repr(prod))

# TEST 9 -- product of matrix1 with a real-valued ot.Point.
print("test 9 : multiplication with a numerical point method")
pt = ot.Point()
for coordinate in (1.0, 2.0):
    pt.add(coordinate)
print("pt = ", repr(pt))

ptResult = matrix1 * pt
print("ptResult = ", repr(ptResult))

# TEST 10 -- multiplication and division of matrix1 by a complex scalar.
print("test 10 : multiplication and division by a numerical scalar methods")
s = 3.0 + 1j

scalprod1 = matrix1 * s
print("scalprod1 = ", repr(scalprod1))

scaldiv1 = matrix1 / s
print("scaldiv1 = ", repr(scaldiv1))

# isEmpty() on one filled matrix and three default-constructed ones.
# NOTE(review): this header repeats the "test 10" number already used by the
# scalar multiplication/division test; left as-is because the printed text
# is this script's output (possibly compared to a stored expected output).
print("test 10 : isEmpty method")

matrix7, matrix8 = ot.ComplexMatrix(), ot.ComplexMatrix()

# NOTE(review): the "matrix5" / "matrix6" labels below actually report
# matrix7 / matrix8 (matrix5 and matrix6 are the 2x2 conjugate results from
# test 5). Kept byte-identical to preserve the script's output.
for label, candidate in (
    ("matrix1 is empty = ", matrix1),
    ("matrix5 is empty = ", matrix7),
    ("matrix6 is empty = ", matrix8),
    ("matrix0 is empty = ", matrix0),
):
    print(label, candidate.isEmpty())

# TEST 11 -- matrix-by-point product where the matrix is B * A, with A a real
# 2x2 rotation matrix and B its transpose (so B * A should equal the
# identity up to floating-point rounding).
print("test 11 : multiplication with a numerical point method")

pt_test = ot.Point()
for coordinate in (1.0, 2.0):
    pt_test.add(coordinate)
print("pt_test = ", repr(pt_test))

# Entries: cos = 0.5 on the diagonal, +/- sqrt(3)/2 off the diagonal.
half_sqrt3 = m.sqrt(3.0) / 2
A = ot.ComplexMatrix(2, 2)
A[0, 0] = 0.5
A[1, 0] = -half_sqrt3
A[0, 1] = half_sqrt3
A[1, 1] = 0.5
# A has only real entries, so its plain transpose is also its inverse.
B = A.transpose()
identity = B * A

ptResult2 = identity * pt_test
for label, shown in (
    ("A = ", A),
    ("B = ", B),
    ("identity = ", identity),
    ("ptResult2 = ", ptResult2),
):
    print(label, repr(shown))
