/*=========================================================================

  Program:   Visualization Toolkit
  Module:    $RCSfile: vtkGridTransform.cxx,v $
  Language:  C++
  Date:      $Date: 2000/12/10 20:08:24 $
  Version:   $Revision: 1.9 $
  Thanks:    Thanks to David G. Gobbi who developed this class.

Copyright (c) 1993-2001 Ken Martin, Will Schroeder, Bill Lorensen.

This software is copyrighted by Ken Martin, Will Schroeder and Bill Lorensen.
The following terms apply to all files associated with the software unless
explicitly disclaimed in individual files. This copyright specifically does
not apply to the related textbook "The Visualization Toolkit" ISBN
013199837-4 published by Prentice Hall which is covered by its own copyright.

The authors hereby grant permission to use, copy, and distribute this
software and its documentation for any purpose, provided that existing
copyright notices are retained in all copies and that this notice is included
verbatim in any distributions. Additionally, the authors grant permission to
modify this software and its documentation for any purpose, provided that
such modifications are not distributed without the explicit consent of the
authors and that existing copyright notices are retained in all copies. Some
of the algorithms implemented by this software are patented, observe all
applicable patent law.

IN NO EVENT SHALL THE AUTHORS OR DISTRIBUTORS BE LIABLE TO ANY PARTY FOR
DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT
OF THE USE OF THIS SOFTWARE, ITS DOCUMENTATION, OR ANY DERIVATIVES THEREOF,
EVEN IF THE AUTHORS HAVE BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

THE AUTHORS AND DISTRIBUTORS SPECIFICALLY DISCLAIM ANY WARRANTIES, INCLUDING,
BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY, FITNESS FOR A
PARTICULAR PURPOSE, AND NON-INFRINGEMENT.  THIS SOFTWARE IS PROVIDED ON AN
"AS IS" BASIS, AND THE AUTHORS AND DISTRIBUTORS HAVE NO OBLIGATION TO PROVIDE
MAINTENANCE, SUPPORT, UPDATES, ENHANCEMENTS, OR MODIFICATIONS.

=========================================================================*/
#include "vtkGridTransform.h"
#include "vtkObjectFactory.h"
#include "vtkMath.h"
#include "math.h"

//----------------------------------------------------------------------------
// Create a vtkGridTransform.  The object factory gets the first chance to
// supply an instance (allowing run-time overrides); when it declines, the
// built-in implementation is constructed directly.
vtkGridTransform* vtkGridTransform::New()
{
  vtkObject* factoryObject =
    vtkObjectFactory::CreateInstance("vtkGridTransform");
  if (!factoryObject)
    {
    return new vtkGridTransform;
    }
  return static_cast<vtkGridTransform*>(factoryObject);
}

//----------------------------------------------------------------------------
// fast floor() function for converting a float to an int
// (the floor() implementation on some computers is much slower than this,
// because they require some 'exact' behaviour that we don't).

// Split 'x' into floor(x), which is returned, and the non-negative
// remainder x - floor(x), which is stored in 'f'.
static inline int vtkGridFloor(float x, float &f)
{
  int whole = int(x);        // truncates toward zero
  float frac = x - whole;
  if (frac < 0)
    {
    // negative non-integer: truncation rounded up, so step down by one
    whole = whole - 1;
    frac = x - whole;
    }
  f = frac;
  return whole;
}

// Return floor(x) as an int, without computing the fractional part.
static inline int vtkGridFloor(float x)
{
  int whole = int(x);        // truncates toward zero
  // for negative non-integers truncation rounded up, so correct by one
  return (x - whole < 0) ? whole - 1 : whole;
}

//----------------------------------------------------------------------------
// Nearest-neighbor interpolation of a displacement grid.
// The displacement as well as the derivatives are returned.
// There are two versions: one which computes the derivatives,
// and one which doesn't.

// Copy the three displacement components stored 'increment' elements past
// 'gridPtr' into 'displacement', converting each one to float.
template <class T>
static inline void vtkNearestHelper(float displacement[3], T *gridPtr, 
				    int increment)
{
  T *node = gridPtr + increment;
  for (int c = 0; c < 3; c++)
    {
    displacement[c] = node[c];
    }
}

// Nearest-neighbor lookup of the displacement at 'point' (continuous i,j,k
// index coordinates); this variant computes no derivatives.  The point is
// rounded to the nearest node and, per axis, clamped into the extent
// 'gridExt'.  'gridInc' holds the element increments between neighbouring
// nodes along x, y and z.  Grids with a scalar type not handled below
// produce a zero displacement.
static inline void vtkNearestNeighborInterpolation(float point[3], 
					     float displacement[3],
					     void *gridPtr, int gridType,
					     int gridExt[6], int gridInc[3])
{
  // round to the nearest node, floor(x + 0.5), relative to the extent start
  int gridId[3];
  gridId[0] = vtkGridFloor(point[0]+0.5f)-gridExt[0];
  gridId[1] = vtkGridFloor(point[1]+0.5f)-gridExt[2];
  gridId[2] = vtkGridFloor(point[2]+0.5f)-gridExt[4];
  
  int ext[3];
  ext[0] = gridExt[1]-gridExt[0];
  ext[1] = gridExt[3]-gridExt[2];
  ext[2] = gridExt[5]-gridExt[4];

  // do bounds check, most points will be inside so optimize for that:
  // the bitwise OR is negative iff at least one operand is negative,
  // i.e. iff some index lies outside [0, ext]
  if ((gridId[0] | (ext[0] - gridId[0]) |
       gridId[1] | (ext[1] - gridId[1]) |
       gridId[2] | (ext[2] - gridId[2])) < 0)
    {
    for (int i = 0; i < 3; i++)
      {
      if (gridId[i] < 0)
        {
        gridId[i] = 0; 
        }
      else if (gridId[i] > ext[i])
        {
        gridId[i] = ext[i];
        }
      }
    }

  // element offset of the selected node
  int increment = gridId[0]*gridInc[0] + 
                  gridId[1]*gridInc[1] + 
                  gridId[2]*gridInc[2];

  switch (gridType)
    {
    case VTK_CHAR:
      vtkNearestHelper(displacement, (char *)gridPtr, increment);
      break;
    case VTK_UNSIGNED_CHAR:
      vtkNearestHelper(displacement, (unsigned char *)gridPtr, increment); 
      break;
    case VTK_SHORT:
      vtkNearestHelper(displacement, (short *)gridPtr, increment);
      break;
    case VTK_UNSIGNED_SHORT:
      vtkNearestHelper(displacement, (unsigned short *)gridPtr, increment);
      break;
    case VTK_FLOAT:
      vtkNearestHelper(displacement, (float *)gridPtr, increment);
      break;
    default:
      // Unsupported scalar type: no helper can read it, so report a zero
      // displacement instead of leaving the output uninitialized.
      displacement[0] = 0.0f;
      displacement[1] = 0.0f;
      displacement[2] = 0.0f;
      break;
    }
}

// Read the displacement at the nearest node 'gridId' and estimate the
// derivatives by differencing, along each axis, the nodes 'gridId1' and
// 'gridId0' (the other two axes keep the nearest node's index).
// derivatives[c][a] receives the difference of component c along axis a.
template <class T>
static inline void vtkNearestHelper(float displacement[3],
				    float derivatives[3][3],
				    T *gridPtr, int gridId[3],
				    int gridId0[3], int gridId1[3],
				    int gridInc[3])
{
  // per-axis element offsets: nearest node (mid), lower (lo), upper (hi)
  int mid[3], lo[3], hi[3];
  for (int i = 0; i < 3; i++)
    {
    mid[i] = gridId[i]*gridInc[i];
    lo[i] = gridId0[i]*gridInc[i];
    hi[i] = gridId1[i]*gridInc[i];
    }

  T *node = gridPtr + (mid[0] + mid[1] + mid[2]);
  for (int c = 0; c < 3; c++)
    {
    displacement[c] = node[c];
    }

  for (int axis = 0; axis < 3; axis++)
    {
    // swap in the lower/upper offset on this axis only
    int offset0[3] = { mid[0], mid[1], mid[2] };
    int offset1[3] = { mid[0], mid[1], mid[2] };
    offset0[axis] = lo[axis];
    offset1[axis] = hi[axis];
    T *below = gridPtr + (offset0[0] + offset0[1] + offset0[2]);
    T *above = gridPtr + (offset1[0] + offset1[1] + offset1[2]);
    for (int comp = 0; comp < 3; comp++)
      {
      derivatives[comp][axis] = above[comp] - below[comp];
      }
    }
}

// Nearest-neighbor lookup of the displacement at 'point' (continuous i,j,k
// index coordinates) together with finite-difference derivatives.
// If 'derivatives' is NULL this defers to the displacement-only variant.
// The derivative along each axis is the difference between the nodes at
// floor(x) and floor(x)+1; axes outside the extent are clamped and get a
// zero derivative.  Grids with a scalar type not handled below produce zero
// displacement and zero derivatives.
static void vtkNearestNeighborInterpolation(float point[3], 
					    float displacement[3],
					    float derivatives[3][3],
					    void *gridPtr, int gridType,
					    int gridExt[6], int gridInc[3])
{
  if (derivatives == NULL)
    {
    vtkNearestNeighborInterpolation(point,displacement,gridPtr,gridType,
				    gridExt,gridInc);
    return;
    }

  // lower node (gridId0) and upper node (gridId1) of the enclosing cell
  float f[3];
  int gridId0[3];
  gridId0[0] = vtkGridFloor(point[0],f[0])-gridExt[0];
  gridId0[1] = vtkGridFloor(point[1],f[1])-gridExt[2];
  gridId0[2] = vtkGridFloor(point[2],f[2])-gridExt[4];

  int gridId[3], gridId1[3];
  gridId[0] = gridId1[0] = gridId0[0] + 1;
  gridId[1] = gridId1[1] = gridId0[1] + 1;
  gridId[2] = gridId1[2] = gridId0[2] + 1;

  // the nearest node is the lower one when the fraction is below one half
  if (f[0] < 0.5) 
    {
    gridId[0] = gridId0[0];
    }
  if (f[1] < 0.5) 
    {
    gridId[1] = gridId0[1];
    }
  if (f[2] < 0.5) 
    {
    gridId[2] = gridId0[2];
    }
  
  int ext[3];
  ext[0] = gridExt[1] - gridExt[0];
  ext[1] = gridExt[3] - gridExt[2];
  ext[2] = gridExt[5] - gridExt[4];

  // do bounds check, most points will be inside so optimize for that:
  // the bitwise OR is negative iff at least one operand is negative
  if ((gridId0[0] | (ext[0] - gridId1[0]) |
       gridId0[1] | (ext[1] - gridId1[1]) |
       gridId0[2] | (ext[2] - gridId1[2])) < 0)
    {
    for (int i = 0; i < 3; i++) 
      {
      if (gridId0[i] < 0)
        {
        gridId[i] = 0;
        gridId0[i] = 0;
        gridId1[i] = 0;
        }
      else if (gridId1[i] > ext[i])
        {
        gridId[i] = ext[i];
        gridId0[i] = ext[i];
        gridId1[i] = ext[i];
        }
      }
    }

  // do nearest-neighbor interpolation
  switch (gridType)
    {
    case VTK_CHAR:
      vtkNearestHelper(displacement, derivatives, (char *)gridPtr, 
		       gridId, gridId0, gridId1, gridInc);
      break;
    case VTK_UNSIGNED_CHAR:
      vtkNearestHelper(displacement, derivatives, (unsigned char *)gridPtr, 
		       gridId, gridId0, gridId1, gridInc);
      break;
    case VTK_SHORT:
      vtkNearestHelper(displacement, derivatives, (short *)gridPtr, 
		       gridId, gridId0, gridId1, gridInc);
      break;
    case VTK_UNSIGNED_SHORT:
      vtkNearestHelper(displacement, derivatives, (unsigned short *)gridPtr, 
		       gridId, gridId0, gridId1, gridInc);
      break;
    case VTK_FLOAT:
      vtkNearestHelper(displacement, derivatives, (float *)gridPtr, 
		       gridId, gridId0, gridId1, gridInc);
      break;
    default:
      // Unsupported scalar type: no helper can read it, so report zero
      // displacement and derivatives instead of leaving them uninitialized.
      displacement[0] = 0.0f;
      displacement[1] = 0.0f;
      displacement[2] = 0.0f;
      for (int c = 0; c < 3; c++)
        {
        derivatives[c][0] = 0.0f;
        derivatives[c][1] = 0.0f;
        derivatives[c][2] = 0.0f;
        }
      break;
    }
}

//----------------------------------------------------------------------------
// Trilinear interpolation of a displacement grid.
// The displacement as well as the derivatives are returned.

// Trilinear interpolation of the three displacement components around
// 'gridPtr'.  i000..i111 are the element offsets of the eight corner nodes
// (digits in x,y,z order; 0 = lower node, 1 = upper node) and fx,fy,fz are
// the fractional positions inside the cell.  When 'derivatives' is not
// NULL, derivatives[c][a] receives d(displacement[c])/d(fraction along a).
template <class T>
static inline void vtkLinearHelper(float displacement[3], 
				   float derivatives[3][3],
				   float fx, float fy, float fz, T *gridPtr, 
				   int i000, int i001, int i010, int i011,
				   int i100, int i101, int i110, int i111)
{
  // complements of the fractional offsets
  float rx = 1 - fx;
  float ry = 1 - fy;
  float rz = 1 - fz;

  // y-z weight products (also the x-derivative weights)
  float ryrz = ry*rz;
  float ryfz = ry*fz;
  float fyrz = fy*rz;
  float fyfz = fy*fz;

  // the eight corner weights
  float w000 = rx*ryrz;
  float w001 = rx*ryfz;
  float w010 = rx*fyrz;
  float w011 = rx*fyfz;
  float w100 = fx*ryrz;
  float w101 = fx*ryfz;
  float w110 = fx*fyrz;
  float w111 = fx*fyfz;

  if (derivatives == NULL)
    {
    for (int c = 0; c < 3; c++)
      {
      T *g = gridPtr + c;
      displacement[c] = (w000*g[i000] + w001*g[i001] +
                         w010*g[i010] + w011*g[i011] +
                         w100*g[i100] + w101*g[i101] +
                         w110*g[i110] + w111*g[i111]);
      }
    return;
    }

  // x-z and x-y weight products, needed only for the derivatives
  float rxrz = rx*rz;
  float rxfz = rx*fz;
  float fxrz = fx*rz;
  float fxfz = fx*fz;
  float rxry = rx*ry;
  float rxfy = rx*fy;
  float fxry = fx*ry;
  float fxfy = fx*fy;

  for (int comp = 0; comp < 3; comp++)
    {
    T *g = gridPtr + comp;
    displacement[comp] = (w000*g[i000] + w001*g[i001] +
                          w010*g[i010] + w011*g[i011] +
                          w100*g[i100] + w101*g[i101] +
                          w110*g[i110] + w111*g[i111]);

    // differences across x, weighted by the y-z products
    derivatives[comp][0] = (ryrz*(g[i100] - g[i000]) +
                            ryfz*(g[i101] - g[i001]) +
                            fyrz*(g[i110] - g[i010]) +
                            fyfz*(g[i111] - g[i011]));

    // differences across y, weighted by the x-z products
    derivatives[comp][1] = (rxrz*(g[i010] - g[i000]) +
                            rxfz*(g[i011] - g[i001]) +
                            fxrz*(g[i110] - g[i100]) +
                            fxfz*(g[i111] - g[i101]));

    // differences across z, weighted by the x-y products
    derivatives[comp][2] = (rxry*(g[i001] - g[i000]) +
                            rxfy*(g[i011] - g[i010]) +
                            fxry*(g[i101] - g[i100]) +
                            fxfy*(g[i111] - g[i110]));
    }
}

// Trilinear interpolation of the displacement grid at 'point' (continuous
// i,j,k index coordinates).  'gridExt' is the grid extent and 'gridInc'
// the element increments between neighbouring nodes along x, y and z.
// On any axis where the point falls outside the extent, it is clamped to
// the boundary node.  If 'derivatives' is not NULL it receives
// d(displacement)/d(index).  Grids with a scalar type not handled below
// produce zero displacement (and zero derivatives).
static void vtkTrilinearInterpolation(float point[3], 
				      float displacement[3],
				      float derivatives[3][3],
				      void *gridPtr, int gridType, 
				      int gridExt[6], int gridInc[3])
{
  // change point into integer plus fraction
  float f[3];
  int floorX = vtkGridFloor(point[0],f[0]);
  int floorY = vtkGridFloor(point[1],f[1]);
  int floorZ = vtkGridFloor(point[2],f[2]);

  // lower corner of the enclosing cell, relative to the extent start
  int gridId0[3];
  gridId0[0] = floorX - gridExt[0];
  gridId0[1] = floorY - gridExt[2];
  gridId0[2] = floorZ - gridExt[4];

  // upper corner of the enclosing cell
  int gridId1[3];
  gridId1[0] = gridId0[0] + 1;
  gridId1[1] = gridId0[1] + 1;
  gridId1[2] = gridId0[2] + 1;

  int ext[3];
  ext[0] = gridExt[1] - gridExt[0];
  ext[1] = gridExt[3] - gridExt[2];
  ext[2] = gridExt[5] - gridExt[4];

  // do bounds check, most points will be inside so optimize for that:
  // the bitwise OR is negative iff at least one operand is negative
  if ((gridId0[0] | (ext[0] - gridId1[0]) |
       gridId0[1] | (ext[1] - gridId1[1]) |
       gridId0[2] | (ext[2] - gridId1[2])) < 0)
    {
    for (int i = 0; i < 3; i++)
      {
      if (gridId0[i] < 0)
        {
        gridId0[i] = 0;
        gridId1[i] = 0;
        f[i] = 0;
        }
      else if (gridId1[i] > ext[i])
        {
        gridId0[i] = ext[i];
        gridId1[i] = ext[i];
        f[i] = 0;
        }
      }
    }

  // do trilinear interpolation
  int factX0 = gridId0[0]*gridInc[0];
  int factY0 = gridId0[1]*gridInc[1];
  int factZ0 = gridId0[2]*gridInc[2];

  int factX1 = gridId1[0]*gridInc[0];
  int factY1 = gridId1[1]*gridInc[1];
  int factZ1 = gridId1[2]*gridInc[2];

  // element offsets of the eight corners (digits in x,y,z order)
  int i000 = factX0+factY0+factZ0;
  int i001 = factX0+factY0+factZ1;
  int i010 = factX0+factY1+factZ0;
  int i011 = factX0+factY1+factZ1;
  int i100 = factX1+factY0+factZ0;
  int i101 = factX1+factY0+factZ1;
  int i110 = factX1+factY1+factZ0;
  int i111 = factX1+factY1+factZ1;
  
  switch (gridType)
    {
    case VTK_CHAR:
      vtkLinearHelper(displacement, derivatives, f[0], f[1], f[2], 
		      (char *)gridPtr,
		      i000, i001, i010, i011, i100, i101, i110, i111);
      break;
    case VTK_UNSIGNED_CHAR:
      vtkLinearHelper(displacement, derivatives, f[0], f[1], f[2], 
		      (unsigned char *)gridPtr,
		      i000, i001, i010, i011, i100, i101, i110, i111);
      break;
    case VTK_SHORT:
      vtkLinearHelper(displacement, derivatives, f[0], f[1], f[2], 
		      (short *)gridPtr, 
		      i000, i001, i010, i011, i100, i101, i110, i111);
      break;
    case VTK_UNSIGNED_SHORT:
      vtkLinearHelper(displacement, derivatives, f[0], f[1], f[2], 
		      (unsigned short *)gridPtr,
		      i000, i001, i010, i011, i100, i101, i110, i111);
      break;
    case VTK_FLOAT:
      vtkLinearHelper(displacement, derivatives, f[0], f[1], f[2], 
		      (float *)gridPtr,
		      i000, i001, i010, i011, i100, i101, i110, i111);
      break;
    default:
      // Unsupported scalar type: no helper can read it, so report zero
      // displacement (and derivatives) instead of leaving them uninitialized.
      displacement[0] = 0.0f;
      displacement[1] = 0.0f;
      displacement[2] = 0.0f;
      if (derivatives)
        {
        for (int c = 0; c < 3; c++)
          {
          derivatives[c][0] = 0.0f;
          derivatives[c][1] = 0.0f;
          derivatives[c][2] = 0.0f;
          }
        }
      break;
    }
}

//----------------------------------------------------------------------------
// Do tricubic interpolation of the displacement grid 'gridPtr' of extent
// 'gridExt' at the 'point'.  The result is placed in 'displacement' and,
// when 'derivatives' is not NULL, the derivatives are placed there too.
// Each grid node holds three displacement components.

// The tricubic interpolation ensures that both the intensity and
// the first derivative of the intensity are smooth across the
// image.  The first derivative is estimated using a 
// centered-difference calculation.


// helper function: set up the lookup indices and the interpolation 
// coefficients

// Choose the stencil range and interpolation weights for one axis of the
// tricubic interpolator.
//
// The four stencil slots 0..3 stand for the grid nodes at offsets
// -1, 0, +1, +2 from floor(x), and 'f' is the fractional offset from the
// node in slot 1.  On return, only slots [*l, *m) carry weights that are
// meant to be used.  In 'interpMode', bit 0 enables interpolation, bit 1
// says the node at +2 is available and bit 2 says the node at -1 is
// available.  A mode outside 0..7 leaves every output untouched.
void vtkSetTricubicInterpCoeffs(float F[4], int *l, int *m, float f, 
				int interpMode)
{
  if (interpMode == 7)
    {
    // all four nodes available: cubic weights
    float fMinus1 = f-1;
    *l = 0; *m = 4;
    F[0] = -f*fMinus1*fMinus1/2;
    F[1] = ((3*f-2)*f-2)*fMinus1/2;
    F[2] = -((3*f-4)*f-1)*f/2;
    F[3] = f*f*fMinus1/2;
    }
  else if (interpMode == 5)
    {
    // node at +2 missing: quadratic through the nodes at -1, 0, +1
    float fPlus1 = f+1;
    float fMinus1 = f-1;
    *l = 0; *m = 3;
    F[0] = f*fMinus1/2;
    F[1] = -fPlus1*fMinus1;
    F[2] = fPlus1*f/2;
    F[3] = 0;
    }
  else if (interpMode == 3)
    {
    // node at -1 missing: quadratic through the nodes at 0, +1, +2
    float fMinus1 = f-1;
    float fMinus2 = fMinus1-1;
    *l = 1; *m = 4;
    F[0] = 0;
    F[1] = fMinus1*fMinus2/2;
    F[2] = -f*fMinus2;
    F[3] = f*fMinus1/2;
    }
  else if (interpMode == 1)
    {
    // only the two central nodes: linear weights
    *l = 1; *m = 3;
    F[0] = 0;
    F[1] = 1-f;
    F[2] = f;
    F[3] = 0;
    }
  else if (interpMode == 0 || interpMode == 2 ||
           interpMode == 4 || interpMode == 6)
    {
    // interpolation disabled: use the central node as-is
    *l = 1; *m = 2;
    F[0] = 0;
    F[1] = 1;
    F[2] = 0;
    F[3] = 0;
    }
}

// set coefficients to be used to find the derivative of the cubic
// Like the interpolation-coefficient helper, but also returns in 'G' the
// derivative of each weight in 'F' with respect to 'f'.
//
// Stencil slots 0..3 stand for the nodes at offsets -1, 0, +1, +2 from
// floor(x); only slots [*l, *m) are meant to be used.  'interpMode' bit 0
// enables interpolation, bit 1 marks the +2 node as available and bit 2
// marks the -1 node as available.  A mode outside 0..7 writes nothing.
void vtkSetTricubicDerivCoeffs(float F[4], float G[4], int *l, int *m, 
			       float f, int interpMode)
{
  if (interpMode == 7)
    {
    // cubic weights and their derivatives
    float fMinus1 = f-1;
    *l = 0; *m = 4;
    F[0] = -f*fMinus1*fMinus1/2;
    F[1] = ((3*f-2)*f-2)*fMinus1/2;
    F[2] = -((3*f-4)*f-1)*f/2;
    F[3] = f*f*fMinus1/2;
    G[0] = -((3*f-4)*f+1)/2;
    G[1] =  (9*f-10)*f/2;
    G[2] = -((9*f-8)*f-1)/2;
    G[3] =  (3*f-2)*f/2;
    }
  else if (interpMode == 5)
    {
    // quadratic through the nodes at -1, 0, +1
    float fPlus1 = f+1;
    float fMinus1 = f-1;
    *l = 0; *m = 3;
    F[0] = f*fMinus1/2;
    F[1] = -fPlus1*fMinus1;
    F[2] = fPlus1*f/2;
    F[3] = 0;
    G[0] = f-0.5;
    G[1] = -2*f;
    G[2] = f+0.5;
    G[3] = 0;
    }
  else if (interpMode == 3)
    {
    // quadratic through the nodes at 0, +1, +2
    float fMinus1 = f-1;
    float fMinus2 = fMinus1-1;
    *l = 1; *m = 4;
    F[0] = 0;
    F[1] = fMinus1*fMinus2/2;
    F[2] = -f*fMinus2;
    F[3] = f*fMinus1/2;
    G[0] = 0;
    G[1] = f-1.5;
    G[2] = 2-2*f;
    G[3] = f-0.5;
    }
  else if (interpMode == 1)
    {
    // linear between the two central nodes: constant slope
    *l = 1; *m = 3;
    F[0] = 0;
    F[1] = 1-f;
    F[2] = f;
    F[3] = 0;
    G[0] =  0;
    G[1] = -1;
    G[2] =  1;
    G[3] =  0;
    }
  else if (interpMode == 0 || interpMode == 2 ||
           interpMode == 4 || interpMode == 6)
    {
    // interpolation disabled: central node only, zero slope
    *l = 1; *m = 2;
    F[0] = 0;
    F[1] = 1;
    F[2] = 0;
    F[3] = 0;
    G[0] = 0;
    G[1] = 0;
    G[2] = 0;
    G[3] = 0;
    }
}

// tricubic interpolation of a warp grid with derivatives
// (set derivatives to NULL to avoid computing them).

// Separable tricubic interpolation (or a lower-order mix per axis) of the
// three displacement components over an up-to 4x4x4 stencil.  factX/Y/Z
// hold the element offsets of the four stencil slots on each axis.  The
// interpModes pick, per axis, which slots are used and with which weights.
// When 'derivatives' is not NULL, derivatives[c][a] receives
// d(displacement[c])/d(fraction along axis a).
template <class T>
static inline void vtkCubicHelper(float displacement[3], 
				  float derivatives[3][3],
				  float fx, float fy, float fz, T *gridPtr,
				  int interpModeX, int interpModeY, 
				  int interpModeZ,
				  int factX[4], int factY[4], int factZ[4])
{
  float wX[4],wY[4],wZ[4];          // interpolation weights per axis
  float dX[4],dY[4],dZ[4];          // derivative weights per axis
  int xLo,xHi,yLo,yHi,zLo,zHi;      // half-open slot ranges per axis

  if (derivatives != NULL)
    {
    for (int r = 0; r < 3; r++)
      {
      derivatives[r][0] = 0.0f;
      derivatives[r][1] = 0.0f;
      derivatives[r][2] = 0.0f;
      }
    vtkSetTricubicDerivCoeffs(wX,dX,&xLo,&xHi,fx,interpModeX);
    vtkSetTricubicDerivCoeffs(wY,dY,&yLo,&yHi,fy,interpModeY);
    vtkSetTricubicDerivCoeffs(wZ,dZ,&zLo,&zHi,fz,interpModeZ);
    }
  else
    {
    vtkSetTricubicInterpCoeffs(wX,&xLo,&xHi,fx,interpModeX);
    vtkSetTricubicInterpCoeffs(wY,&yLo,&yHi,fy,interpModeY);
    vtkSetTricubicInterpCoeffs(wZ,&zLo,&zHi,fz,interpModeZ);
    }

  displacement[0] = 0;
  displacement[1] = 0;
  displacement[2] = 0;

  // accumulate along x into sumY, then along y into sumZ, then along z
  for (int j = zLo; j < zHi; j++)
    {
    T *slice = gridPtr + factZ[j];
    float sumZ[3] = { 0, 0, 0 };
    for (int k = yLo; k < yHi; k++)
      {
      T *line = slice + factY[k];
      float sumY[3] = { 0, 0, 0 };
      for (int i = xLo; i < xHi; i++)
        {
        T *node = line + factX[i];
        float w = wX[i];
        if (derivatives == NULL)
          {
          for (int c = 0; c < 3; c++)
            {
            sumY[c] += node[c] * w;
            }
          }
        else
          {
          // full 3D derivative weights for this node
          float wdx = dX[i]*wY[k]*wZ[j];
          float wdy = wX[i]*dY[k]*wZ[j];
          float wdz = wX[i]*wY[k]*dZ[j];
          for (int c = 0; c < 3; c++)
            {
            float value = node[c];
            sumY[c] += value * w;
            derivatives[c][0] += value * wdx;
            derivatives[c][1] += value * wdy;
            derivatives[c][2] += value * wdz;
            }
          }
        }
      for (int c = 0; c < 3; c++)
        {
        sumZ[c] += sumY[c]*wY[k];
        }
      }
    for (int c = 0; c < 3; c++)
      {
      displacement[c] += sumZ[c]*wZ[j];
      }
    }
}

// Tricubic interpolation of the displacement grid at 'point' (continuous
// i,j,k index coordinates), with derivatives when 'derivatives' is not
// NULL.  Near the edges of the extent the scheme degrades per axis to
// quadratic, linear or no interpolation, depending on which neighbouring
// nodes exist.  Grids with a scalar type not handled below produce zero
// displacement (and zero derivatives).
static void vtkTricubicInterpolation(float point[3],
				     float displacement[3], 
				     float derivatives[3][3],
				     void *gridPtr, int gridType, 
				     int gridExt[6], int gridInc[3])
{
  int factX[4],factY[4],factZ[4];

  // change point into integer plus fraction
  float f[3];
  int floorX = vtkGridFloor(point[0],f[0]);
  int floorY = vtkGridFloor(point[1],f[1]);
  int floorZ = vtkGridFloor(point[2],f[2]);

  int gridId0[3];
  gridId0[0] = floorX - gridExt[0];
  gridId0[1] = floorY - gridExt[2];
  gridId0[2] = floorZ - gridExt[4];

  int gridId1[3];
  gridId1[0] = gridId0[0] + 1;
  gridId1[1] = gridId0[1] + 1;
  gridId1[2] = gridId0[2] + 1;

  int ext[3];
  ext[0] = gridExt[1] - gridExt[0];
  ext[1] = gridExt[3] - gridExt[2];
  ext[2] = gridExt[5] - gridExt[4];

  // doInterp[i] is 0 if interpolation does not have to be done
  // along axis i (the point was clamped on that axis).
  int doInterp[3];
  doInterp[0] = 1;
  doInterp[1] = 1;
  doInterp[2] = 1;

  // do bounds check, most points will be inside so optimize for that:
  // the bitwise OR is negative iff at least one operand is negative
  if ((gridId0[0] | (ext[0] - gridId1[0]) |
       gridId0[1] | (ext[1] - gridId1[1]) |
       gridId0[2] | (ext[2] - gridId1[2])) < 0)
    {
    for (int i = 0; i < 3; i++)
      {
      if (gridId0[i] < 0)
        {
        gridId0[i] = 0;
        gridId1[i] = 0;
        doInterp[i] = 0;
        f[i] = 0;
        }
      else if (gridId1[i] > ext[i])
        {
        gridId0[i] = ext[i];
        gridId1[i] = ext[i];
        doInterp[i] = 0;
        f[i] = 0;
        }
      }
    }

  // element offsets of the stencil nodes at -1, 0, +1, +2 per axis
  for (int i = 0; i < 4; i++)
    {
    factX[i] = (gridId0[0]-1+i)*gridInc[0];
    factY[i] = (gridId0[1]-1+i)*gridInc[1];
    factZ[i] = (gridId0[2]-1+i)*gridInc[2];
    }

  // depending on whether we are at the edge of the input extent, choose
  // the interpolation method: bit 2 = node at -1 exists, bit 1 = node at
  // +2 exists, bit 0 = interpolation enabled on this axis
  int interpModeX = ((gridId0[0] > 0) << 2) + 
                    ((gridId1[0] < ext[0]) << 1) +
                    doInterp[0];
  int interpModeY = ((gridId0[1] > 0) << 2) + 
                    ((gridId1[1] < ext[1]) << 1) +
                    doInterp[1];
  int interpModeZ = ((gridId0[2] > 0) << 2) + 
                    ((gridId1[2] < ext[2]) << 1) +
                    doInterp[2];

  switch (gridType)
    {
    case VTK_CHAR:
      vtkCubicHelper(displacement, derivatives, f[0], f[1], f[2],
		     (char *)gridPtr,
		     interpModeX, interpModeY, interpModeZ,
		     factX, factY, factZ);
      break;
    case VTK_UNSIGNED_CHAR:
      vtkCubicHelper(displacement, derivatives, f[0], f[1], f[2],
		     (unsigned char *)gridPtr,
		     interpModeX, interpModeY, interpModeZ,
		     factX, factY, factZ);
      break;
    case VTK_SHORT:
      vtkCubicHelper(displacement, derivatives, f[0], f[1], f[2],
		     (short *)gridPtr,
		     interpModeX, interpModeY, interpModeZ,
		     factX, factY, factZ);
      break;
    case VTK_UNSIGNED_SHORT:
      vtkCubicHelper(displacement, derivatives, f[0], f[1], f[2],
		     (unsigned short *)gridPtr,
		     interpModeX, interpModeY, interpModeZ,
		     factX, factY, factZ);
      break;
    case VTK_FLOAT:
      vtkCubicHelper(displacement, derivatives, f[0], f[1], f[2],
		     (float *)gridPtr,
		     interpModeX, interpModeY, interpModeZ,
		     factX, factY, factZ);
      break;
    default:
      // Unsupported scalar type: no helper can read it, so report zero
      // displacement (and derivatives) instead of leaving them uninitialized.
      displacement[0] = 0.0f;
      displacement[1] = 0.0f;
      displacement[2] = 0.0f;
      if (derivatives)
        {
        for (int c = 0; c < 3; c++)
          {
          derivatives[c][0] = 0.0f;
          derivatives[c][1] = 0.0f;
          derivatives[c][2] = 0.0f;
          }
        }
      break;
    }
}

//----------------------------------------------------------------------------
// Construct a grid transform with no displacement grid, unit displacement
// scale, zero shift and trilinear interpolation.
vtkGridTransform::vtkGridTransform()
{
  this->DisplacementGrid = NULL;
  this->DisplacementScale = 1.0;
  this->DisplacementShift = 0.0;

  // the interpolation mode and the function pointer are set together
  this->InterpolationMode = VTK_GRID_LINEAR;
  this->InterpolationFunction = &vtkTrilinearInterpolation;

  // the grid warp has a fairly large tolerance
  this->InverseTolerance = 0.01;
}

//----------------------------------------------------------------------------
// Detach the displacement grid through its setter, so whatever the setter
// does with the old grid (presumably releasing a reference) also happens
// on destruction.
vtkGridTransform::~vtkGridTransform()
{
  this->SetDisplacementGrid(NULL);
}

//----------------------------------------------------------------------------
// Print the superclass state, then the grid-specific settings, then the
// displacement grid itself (one indentation level deeper) if one is set.
void vtkGridTransform::PrintSelf(ostream& os, vtkIndent indent)
{
  vtkWarpTransform::PrintSelf(os,indent);

  os << indent << "InterpolationMode: " 
     << this->GetInterpolationModeAsString() << "\n"
     << indent << "DisplacementScale: " << this->DisplacementScale << "\n"
     << indent << "DisplacementShift: " << this->DisplacementShift << "\n"
     << indent << "DisplacementGrid: " << this->DisplacementGrid << "\n";

  if (this->DisplacementGrid != NULL)
    {
    this->DisplacementGrid->PrintSelf(os,indent.GetNextIndent());
    }
}

//----------------------------------------------------------------------------
// need to check the input image data to determine MTime
unsigned long vtkGridTransform::GetMTime()
{
  unsigned long mtime,result;
  result = vtkWarpTransform::GetMTime();
  if (this->DisplacementGrid)
    {
    this->DisplacementGrid->UpdateInformation();

    mtime = this->DisplacementGrid->GetPipelineMTime();
    result = ( mtime > result ? mtime : result );    

    mtime = this->DisplacementGrid->GetMTime();
    result = ( mtime > result ? mtime : result );
    }

  return result;
}

//----------------------------------------------------------------------------
// Select the interpolation scheme (VTK_GRID_NEAREST, VTK_GRID_LINEAR or
// VTK_GRID_CUBIC) together with the matching interpolation function.
// An illegal mode is reported and rejected.  The object is then left
// unchanged, so InterpolationMode always agrees with InterpolationFunction
// and Modified() is only called for a real change.
void vtkGridTransform::SetInterpolationMode(int mode)
{
  if (mode == this->InterpolationMode)
    {
    return;
    }
  switch(mode)
    {
    case VTK_GRID_NEAREST:
      this->InterpolationFunction = &vtkNearestNeighborInterpolation;
      break;
    case VTK_GRID_LINEAR:
      this->InterpolationFunction = &vtkTrilinearInterpolation;
      break;
    case VTK_GRID_CUBIC:
      this->InterpolationFunction = &vtkTricubicInterpolation;
      break;
    default:
      vtkErrorMacro( << "SetInterpolationMode: Illegal interpolation mode");
      return;
    }
  // only commit the mode once it is known to be valid
  this->InterpolationMode = mode;
  this->Modified();
}

//----------------------------------------------------------------------------
// Map 'inPoint' through the grid transform: the displacement interpolated
// at the point is scaled by DisplacementScale, offset by DisplacementShift
// and added to the point.  Without a displacement grid the point is copied
// unchanged.  'inPoint' and 'outPoint' may be the same array.
void vtkGridTransform::ForwardTransformPoint(const float inPoint[3], 
					     float outPoint[3])
{
  if (this->DisplacementGrid == NULL)
    {
    for (int n = 0; n < 3; n++)
      {
      outPoint[n] = inPoint[n];
      }
    return;
    }

  vtkImageData *grid = this->DisplacementGrid;
  void *gridPtr = grid->GetScalarPointer();
  int gridType = grid->GetScalarType();
  float *spacing = grid->GetSpacing();
  float *origin = grid->GetOrigin();
  int *extent = grid->GetExtent();
  int *increments = grid->GetIncrements();

  float scale = this->DisplacementScale;
  float shift = this->DisplacementShift;

  // continuous i,j,k index of the input point within the grid
  float index[3];
  for (int axis = 0; axis < 3; axis++)
    {
    index[axis] = (inPoint[axis] - origin[axis])/spacing[axis];
    }

  float displacement[3];
  this->InterpolationFunction(index,displacement,NULL,
			      gridPtr,gridType,extent,increments);

  for (int comp = 0; comp < 3; comp++)
    {
    outPoint[comp] = inPoint[comp] + (displacement[comp]*scale + shift);
    }
}

//----------------------------------------------------------------------------
// convert double to float
// Double-precision entry point: narrow the point to float, run the float
// implementation in place and widen the result back to double.
void vtkGridTransform::ForwardTransformPoint(const double point[3], 
					     double output[3])
{
  float work[3];
  for (int in = 0; in < 3; in++)
    {
    work[in] = point[in];
    }

  this->ForwardTransformPoint(work,work);

  for (int out = 0; out < 3; out++)
    {
    output[out] = work[out];
    }
}

//----------------------------------------------------------------------------
// calculate the derivative of the grid transform.  The derivative comes
// from the currently selected InterpolationFunction (nearest, linear or
// cubic); only cubic interpolation provides a well-behaved derivative.
// Transform 'inPoint' and compute the 3x3 Jacobian of the transform there,
// using whichever InterpolationFunction is currently selected.  Without a
// displacement grid the transform is the identity.  'inPoint' and
// 'outPoint' may be the same array.
void vtkGridTransform::ForwardTransformDerivative(const float inPoint[3],
						  float outPoint[3],
						  float derivative[3][3])
{
  if (this->DisplacementGrid == NULL)
    {
    for (int n = 0; n < 3; n++)
      {
      outPoint[n] = inPoint[n];
      }
    vtkMath::Identity3x3(derivative);
    return;
    }

  vtkImageData *grid = this->DisplacementGrid;
  void *gridPtr = grid->GetScalarPointer();
  int gridType = grid->GetScalarType();
  float *spacing = grid->GetSpacing();
  float *origin = grid->GetOrigin();
  int *extent = grid->GetExtent();
  int *increments = grid->GetIncrements();

  float scale = this->DisplacementScale;
  float shift = this->DisplacementShift;

  // continuous i,j,k index of the input point within the grid
  float index[3];
  for (int axis = 0; axis < 3; axis++)
    {
    index[axis] = (inPoint[axis] - origin[axis])/spacing[axis];
    }

  float displacement[3];
  this->InterpolationFunction(index,displacement,derivative,
			      gridPtr,gridType,extent,increments);

  // The interpolator differentiates with respect to the grid index.
  // Chain rule: scale the displacement, divide by the spacing of the
  // column's axis, then add the identity for the "inPoint +" term.
  for (int row = 0; row < 3; row++)
    {
    for (int col = 0; col < 3; col++)
      {
      derivative[row][col] = derivative[row][col]*scale/spacing[col];
      }
    derivative[row][row] += 1.0f;
    }

  for (int comp = 0; comp < 3; comp++)
    {
    outPoint[comp] = inPoint[comp] + (displacement[comp]*scale + shift);
    }
}

//----------------------------------------------------------------------------
// convert double to float
// Double-precision entry point: evaluate the transform and its Jacobian in
// single precision, then widen both results back to double.
void vtkGridTransform::ForwardTransformDerivative(const double point[3],
						  double output[3],
						  double derivative[3][3])
{
  float work[3];
  float jacobian[3][3];
  for (int in = 0; in < 3; in++)
    {
    work[in] = point[in];
    }

  this->ForwardTransformDerivative(work,work,jacobian);

  for (int row = 0; row < 3; row++)
    {
    for (int col = 0; col < 3; col++)
      {
      derivative[row][col] = jacobian[row][col];
      }
    output[row] = work[row];
    }
}

//----------------------------------------------------------------------------
// We use Newton's method to iteratively invert the transformation.  
// This is actually quite robust as long as the Jacobian matrix is never
// singular.
// Note that this is similar to vtkWarpTransform::InverseTransformPoint()
// but has been optimized specifically for grid transforms.
void vtkGridTransform::InverseTransformDerivative(const float inPoint[3], 
						  float outPoint[3],
						  float derivative[3][3])
{
  if (this->DisplacementGrid == NULL)
    {
    outPoint[0] = inPoint[0]; 
    outPoint[1] = inPoint[1]; 
    outPoint[2] = inPoint[2]; 
    return;
    }

  vtkImageData *grid = this->DisplacementGrid;
  void *gridPtr = grid->GetScalarPointer();
  int gridType = grid->GetScalarType();

  float *spacing = grid->GetSpacing();
  float *origin = grid->GetOrigin();
  int *extent = grid->GetExtent();
  int *increments = grid->GetIncrements();

  float invSpacing[3];
  invSpacing[0] = 1.0f/spacing[0];
  invSpacing[1] = 1.0f/spacing[1];
  invSpacing[2] = 1.0f/spacing[2];

  float shift = this->DisplacementShift;
  float scale = this->DisplacementScale;

  // convert the inPoint to i,j,k indices plus fractions
  float point[3];
  point[0] = (inPoint[0] - origin[0])*invSpacing[0];
  point[1] = (inPoint[1] - origin[1])*invSpacing[1];
  point[2] = (inPoint[2] - origin[2])*invSpacing[2];

  // first guess at inverse point, just subtract displacement
  // (the inverse point is given in i,j,k indices plus fractions)
  float deltaI[3];
  this->InterpolationFunction(point, deltaI, NULL,
			      gridPtr, gridType, extent, increments);
  float inverse[3], lastInverse[3];
  inverse[0] = point[0] - (deltaI[0]*scale + shift)*invSpacing[0];
  inverse[1] = point[1] - (deltaI[1]*scale + shift)*invSpacing[1];
  inverse[2] = point[2] - (deltaI[2]*scale + shift)*invSpacing[2];

  // put the inverse point back through the transform
  float deltaP[3], gradient[3];
  this->InterpolationFunction(inverse, deltaP, derivative,
			      gridPtr, gridType, extent, increments);

  // convert displacement 
  deltaP[0] = (inverse[0] - point[0])*spacing[0] + deltaP[0]*scale + shift;
  deltaP[1] = (inverse[1] - point[1])*spacing[1] + deltaP[1]*scale + shift;
  deltaP[2] = (inverse[2] - point[2])*spacing[2] + deltaP[2]*scale + shift;

  float lastErrorSquared;
  float errorSquared = deltaP[0]*deltaP[0] +
                       deltaP[1]*deltaP[1] +
                       deltaP[2]*deltaP[2];

  float toleranceSquared = this->InverseTolerance*
			   this->InverseTolerance;

  // do a maximum 500 iterations, usually less than 10 are required
  int n = this->InverseIterations;
  int i;
  for (i = 0; i < n && errorSquared > toleranceSquared; i++)
    {
    // save previous error
    lastErrorSquared = errorSquared;

    // convert derivative
    for (int j = 0; j < 3; j++)
      {
      derivative[j][0] = derivative[j][0]*scale*invSpacing[0];
      derivative[j][1] = derivative[j][1]*scale*invSpacing[1];
      derivative[j][2] = derivative[j][2]*scale*invSpacing[2];
      derivative[j][j] += 1.0f;
      }

    // here is the critical step in Newton's method
    vtkMath::LinearSolve3x3(derivative,deltaP,deltaI);

    // save the inverse
    lastInverse[0] = inverse[0];
    lastInverse[1] = inverse[1];
    lastInverse[2] = inverse[2];

    // calculate the gradient of errorSquared
    gradient[0] = deltaP[0]*derivative[0][0]*2;
    gradient[1] = deltaP[1]*derivative[1][1]*2;
    gradient[2] = deltaP[2]*derivative[2][2]*2;

    // calculate the new inverse
    inverse[0] -= deltaI[0]*invSpacing[0];
    inverse[1] -= deltaI[1]*invSpacing[1];
    inverse[2] -= deltaI[2]*invSpacing[2];

    // put the inverse point back through the transform
    this->InterpolationFunction(inverse, deltaP, derivative,
				gridPtr, gridType, extent, increments);

    // convert displacement 
    deltaP[0] = (inverse[0] - point[0])*spacing[0] + deltaP[0]*scale + shift;
    deltaP[1] = (inverse[1] - point[1])*spacing[1] + deltaP[1]*scale + shift;
    deltaP[2] = (inverse[2] - point[2])*spacing[2] + deltaP[2]*scale + shift;

    // add errors for each dimension
    errorSquared = deltaP[0]*deltaP[0] +
                   deltaP[1]*deltaP[1] +
                   deltaP[2]*deltaP[2];

    if (errorSquared > lastErrorSquared)
      { // the error is increasing, backtrack 
	// see Numerical Recipes 9.7 for rationale

      // derivative of errorSquared for lastError
      float lastErrorSquaredD = (gradient[0]*deltaI[0] +
				 gradient[1]*deltaI[1] +
				 gradient[2]*deltaI[2]);

      // quadratic approximation to find best fractional distance
      float f = lastErrorSquaredD/
	  (2*(errorSquared-lastErrorSquared-lastErrorSquaredD));

      if (f < 0.1)
	{
	f = 0.1;
	}
      if (f > 0.5)
	{
	f = 0.5;
	}

      // calculate inverse using fractional distance
      inverse[0] = lastInverse[0] - f*deltaI[0]*invSpacing[0];
      inverse[1] = lastInverse[1] - f*deltaI[1]*invSpacing[1];
      inverse[2] = lastInverse[2] - f*deltaI[2]*invSpacing[2];

      // put the inverse point back through the transform
      this->InterpolationFunction(inverse, deltaP, derivative,
				  gridPtr, gridType, extent, increments);
      
      // convert displacement 
      deltaP[0] = (inverse[0] - point[0])*spacing[0] + deltaP[0]*scale + shift;
      deltaP[1] = (inverse[1] - point[1])*spacing[1] + deltaP[1]*scale + shift;
      deltaP[2] = (inverse[2] - point[2])*spacing[2] + deltaP[2]*scale + shift;

      // add errors for each dimension
      errorSquared = deltaP[0]*deltaP[0] +
	             deltaP[1]*deltaP[1] +
                     deltaP[2]*deltaP[2];
      }
    }

  // convert derivative
  for (int j = 0; j < 3; j++)
    {
    derivative[j][0] = derivative[j][0]*scale*invSpacing[0];
    derivative[j][1] = derivative[j][1]*scale*invSpacing[1];
    derivative[j][2] = derivative[j][2]*scale*invSpacing[2];
    derivative[j][j] += 1.0f;
    }

  // convert point
  outPoint[0] = inverse[0]*spacing[0] + origin[0];
  outPoint[1] = inverse[1]*spacing[1] + origin[1];
  outPoint[2] = inverse[2]*spacing[2] + origin[2];

  vtkDebugMacro("Inverse Iterations: " << (i+1));

  if (i >= this->InverseIterations)
    {
    vtkWarningMacro("InverseTransformPoint: no convergence (" <<
		    point[0] << ", " << point[1] << ", " << point[2] << 
		    ") error = " << sqrt(errorSquared) << " after " <<
		    i << " iterations.");
    }

}

//----------------------------------------------------------------------------
// Double-precision entry point for the inverse: narrow the input point to
// float, invert it with the float implementation, then widen both the
// resulting point and the 3x3 derivative back to double.
void vtkGridTransform::InverseTransformDerivative(const double point[3], 
                                                  double output[3],
                                                  double derivative[3][3])
{
  float fPoint[3];
  float fJacobian[3][3];
  for (int axis = 0; axis < 3; ++axis)
    {
    fPoint[axis] = static_cast<float>(point[axis]);
    }

  // the float version is called with the same array for input and output
  this->InverseTransformDerivative(fPoint, fPoint, fJacobian);

  for (int row = 0; row < 3; ++row)
    {
    output[row] = fPoint[row];
    for (int col = 0; col < 3; ++col)
      {
      derivative[row][col] = fJacobian[row][col];
      }
    }
}

//----------------------------------------------------------------------------
// Invert a single point.  InverseTransformDerivative() needs a 3x3 matrix to
// work in, so a local scratch matrix is supplied and then discarded.
void vtkGridTransform::InverseTransformPoint(const float point[3], 
                                             float output[3])
{
  float scratchJacobian[3][3];
  this->InverseTransformDerivative(point, output, scratchJacobian);
}

//----------------------------------------------------------------------------
// Double-precision point inversion: narrow to float, invert with the float
// InverseTransformDerivative() (whose derivative output is ignored here),
// then widen the result back to double.
void vtkGridTransform::InverseTransformPoint(const double point[3], 
                                             double output[3])
{
  float fPoint[3];
  float unusedJacobian[3][3];
  int axis;

  for (axis = 0; axis < 3; ++axis)
    {
    fPoint[axis] = static_cast<float>(point[axis]);
    }

  this->InverseTransformDerivative(fPoint, fPoint, unusedJacobian);

  for (axis = 0; axis < 3; ++axis)
    {
    output[axis] = fPoint[axis];
    }
}

//----------------------------------------------------------------------------
// Copy all parameters of 'transform' into this object.  'transform' is
// assumed to be a vtkGridTransform.  Each setter is called exactly once.
// The original code called SetDisplacementScale() twice, a redundant
// duplicate that has been removed.
void vtkGridTransform::InternalDeepCopy(vtkAbstractTransform *transform)
{
  vtkGridTransform *gridTransform = static_cast<vtkGridTransform *>(transform);

  this->SetInverseTolerance(gridTransform->InverseTolerance);
  this->SetInverseIterations(gridTransform->InverseIterations);
  this->SetInterpolationMode(gridTransform->InterpolationMode);
  // also copy the function pointer itself so it matches the source exactly
  this->InterpolationFunction = gridTransform->InterpolationFunction;
  this->SetDisplacementScale(gridTransform->DisplacementScale);
  this->SetDisplacementGrid(gridTransform->DisplacementGrid);
  this->SetDisplacementShift(gridTransform->DisplacementShift);

  // InverseFlag has no setter here; mark modified only when it changes
  if (this->InverseFlag != gridTransform->InverseFlag)
    {
    this->InverseFlag = gridTransform->InverseFlag;
    this->Modified();
    }
}

//----------------------------------------------------------------------------
// Prepare the displacement grid before it is sampled.  This refreshes its
// information and checks that it has 3 scalar components of an accepted
// scalar type (char, unsigned char, short, unsigned short, float).  It then
// requests and updates the whole extent.  It does nothing when no grid is
// set, and returns after reporting an error if validation fails.
void vtkGridTransform::InternalUpdate()
{
  vtkImageData *grid = this->DisplacementGrid;
  if (!grid)
    {
    return;
    }

  grid->UpdateInformation();

  if (grid->GetNumberOfScalarComponents() != 3)
    {
    vtkErrorMacro(<< "TransformPoint: displacement grid must have 3 components");
    return;
    }

  switch (grid->GetScalarType())
    {
    case VTK_CHAR:
    case VTK_UNSIGNED_CHAR:
    case VTK_SHORT:
    case VTK_UNSIGNED_SHORT:
    case VTK_FLOAT:
      break;   // accepted scalar types
    default:
      vtkErrorMacro(<< "TransformPoint: displacement grid is of unsupported numerical type");
      return;
    }

  grid->SetUpdateExtent(grid->GetWholeExtent());
  grid->Update();
}

//----------------------------------------------------------------------------
// Factory hook: return a newly created vtkGridTransform instance.
vtkAbstractTransform *vtkGridTransform::MakeTransform()
{
  vtkGridTransform *freshTransform = vtkGridTransform::New();
  return freshTransform;
}
