import * as THREE from "three";

/**
 * SplatBuffer: container for the splat data of a single scene/file, with
 * optional (mediocre) compression.
 *
 * Data layout, as read by the constructor and `linkBufferArrays`:
 * - a `HeaderSizeBytes`-byte header (fields decoded in the constructor);
 * - one section per attribute, each holding `splatCount` entries, in this
 *   order: centers, scales, colors, rotations and, for 4D buffers, right
 *   rotations, mu_t and scale_t;
 * - bucket data starting at `bucketsBase`; for compressed buffers the first
 *   three floats of a splat's bucket are added to its decoded center.
 */
export class SplatBuffer {
  static CenterComponentCount = 3;
  static ScaleComponentCount = 3;
  static RotationComponentCount = 4;
  static ColorComponentCount = 4;
  static MuTComponentCount = 1;
  static ScaleTComponentCount = 1;

  // Per-splat byte sizes of each attribute for every supported compression
  // level. Level 0 stores float32 components, level 1 stores 16-bit
  // components; colors are 8-bit at both levels.
  static CompressionLevels = {
    0: {
      BytesPerCenter: 12,
      BytesPerScale: 12,
      BytesPerColor: 4,
      BytesPerRotation: 16,
      BytesPerScaleT: 4,
      BytesPerMuT: 4,
      ScaleRange: 1,
    },
    1: {
      BytesPerCenter: 6,
      BytesPerScale: 6,
      BytesPerColor: 4,
      BytesPerRotation: 8,
      BytesPerScaleT: 2,
      BytesPerMuT: 2,
      ScaleRange: 32767,
    },
  };

  static CovarianceSizeFloats = 6;
  static CovarianceSizeBytes = 24;
  static Covariance4DSizeFloats = 4;

  static HeaderSizeBytes = 1024;

  /**
   * Parses the header and copies the splat data out of `bufferData`.
   *
   * @param {ArrayBuffer} bufferData - header followed by the splat sections;
   *   both parts are copied, so the caller may reuse the buffer afterwards.
   * @param {boolean} [is4D=false] - whether the buffer carries the extra 4D
   *   sections (right rotation, mu_t, scale_t) after the rotations.
   * @throws {Error} if the header declares an unsupported compression level.
   */
  constructor(bufferData, is4D = false) {
    this.is4D = is4D;
    this.headerBufferData = new ArrayBuffer(SplatBuffer.HeaderSizeBytes);
    this.headerArrayUint8 = new Uint8Array(this.headerBufferData);
    this.headerArrayUint32 = new Uint32Array(this.headerBufferData);
    this.headerArrayFloat32 = new Float32Array(this.headerBufferData);
    this.headerArrayUint8.set(
      new Uint8Array(bufferData, 0, SplatBuffer.HeaderSizeBytes)
    );
    this.versionMajor = this.headerArrayUint8[0];
    this.versionMinor = this.headerArrayUint8[1];
    this.headerExtraK = this.headerArrayUint8[2];
    this.compressionLevel = this.headerArrayUint8[3];
    this.splatCount = this.headerArrayUint32[1];
    this.bucketSize = this.headerArrayUint32[2];
    this.bucketCount = this.headerArrayUint32[3];
    this.bucketBlockSize = this.headerArrayFloat32[4];
    this.halfBucketBlockSize = this.bucketBlockSize / 2.0;
    this.bytesPerBucket = this.headerArrayUint32[5];

    const level = SplatBuffer.CompressionLevels[this.compressionLevel];
    if (!level) {
      throw new Error(
        `Unsupported compression level: ${this.compressionLevel}`
      );
    }

    // A zero in header slot 6 means "not stored": use the level's default.
    this.compressionScaleRange = this.headerArrayUint32[6] || level.ScaleRange;
    this.compressionScaleFactor =
      this.halfBucketBlockSize / this.compressionScaleRange;

    const dataBufferSizeBytes =
      bufferData.byteLength - SplatBuffer.HeaderSizeBytes;
    this.splatBufferData = new ArrayBuffer(dataBufferSizeBytes);
    new Uint8Array(this.splatBufferData).set(
      new Uint8Array(
        bufferData,
        SplatBuffer.HeaderSizeBytes,
        dataBufferSizeBytes
      )
    );

    this.bytesPerCenter = level.BytesPerCenter;
    this.bytesPerScale = level.BytesPerScale;
    this.bytesPerColor = level.BytesPerColor;
    this.bytesPerRotation = level.BytesPerRotation;
    this.bytesPerMuT = level.BytesPerMuT;
    this.bytesPerScaleT = level.BytesPerScaleT;

    this.bytesPerSplat =
      this.bytesPerCenter +
      this.bytesPerScale +
      this.bytesPerColor +
      this.bytesPerRotation;

    if (is4D) {
      // Right rotation, mu_t and scale_t follow the 3D attributes.
      this.bytesPerSplat +=
        this.bytesPerRotation + this.bytesPerMuT + this.bytesPerScaleT;
    }

    this.linkBufferArrays();
  }

  /**
   * Creates typed-array views over each attribute section of
   * `splatBufferData` and records where the bucket data starts.
   */
  linkBufferArrays() {
    // Level 0 stores float32 components; higher levels store 16-bit values.
    const ComponentArray =
      this.compressionLevel === 0 ? Float32Array : Uint16Array;
    const count = this.splatCount;
    let byteOffset = 0;

    // Views the next section (count * componentCount elements) and advances
    // the cursor by that section's byte size.
    const nextSection = (ArrayType, componentCount, bytesPerSplatSection) => {
      const view = new ArrayType(
        this.splatBufferData,
        byteOffset,
        count * componentCount
      );
      byteOffset += bytesPerSplatSection * count;
      return view;
    };

    this.centerArray = nextSection(
      ComponentArray,
      SplatBuffer.CenterComponentCount,
      this.bytesPerCenter
    );
    this.scaleArray = nextSection(
      ComponentArray,
      SplatBuffer.ScaleComponentCount,
      this.bytesPerScale
    );
    this.colorArray = nextSection(
      Uint8Array,
      SplatBuffer.ColorComponentCount,
      this.bytesPerColor
    );
    this.rotationArray = nextSection(
      ComponentArray,
      SplatBuffer.RotationComponentCount,
      this.bytesPerRotation
    );
    if (this.is4D) {
      this.rotationRightArray = nextSection(
        ComponentArray,
        SplatBuffer.RotationComponentCount,
        this.bytesPerRotation
      );
      this.muTArray = nextSection(
        ComponentArray,
        SplatBuffer.MuTComponentCount,
        this.bytesPerMuT
      );
      this.scaleTArray = nextSection(
        ComponentArray,
        SplatBuffer.ScaleTComponentCount,
        this.bytesPerScaleT
      );
    }
    this.bucketsBase = this.splatCount * this.bytesPerSplat;
  }

  /**
   * Decodes a stored scale/rotation component: identity at level 0, a
   * half-float conversion otherwise.
   */
  fbf(f) {
    if (this.compressionLevel === 0) {
      return f;
    }
    return THREE.DataUtils.fromHalfFloat(f);
  }

  getHeaderBufferData() {
    return this.headerBufferData;
  }

  getSplatBufferData() {
    return this.splatBufferData;
  }

  getSplatCount() {
    return this.splatCount;
  }

  /**
   * Builds the 4x4 space-time covariance of splat `index` into
   * `outCovariance` (a THREE.Matrix4) and returns it.
   *
   * With ML / MR the matrices returned by leftQuatToMatrix /
   * rightQuatToMatrix and S = diag(sx, sy, sz, st), this computes
   * F = S · MR · ML and Σ = Fᵀ · F. Scratch objects are shared between calls
   * on the same instance.
   */
  compute4DCovariance = (function () {
    const rotationLeft = new THREE.Quaternion();
    const rotationRight = new THREE.Quaternion();
    const rotationLeftMatrix = new THREE.Matrix4();
    const rotationRightMatrix = new THREE.Matrix4();
    const rotation4DMatrix = new THREE.Matrix4();
    const scale4D = new THREE.Matrix4();
    const factor = new THREE.Matrix4();
    const factorTranspose = new THREE.Matrix4();

    return function (index, outCovariance) {
      const rotationBase = index * SplatBuffer.RotationComponentCount;
      const scaleBase = index * SplatBuffer.ScaleComponentCount;

      // Rotations are stored w-first; THREE.Quaternion.set takes (x, y, z, w).
      rotationLeft.set(
        this.fbf(this.rotationArray[rotationBase + 1]),
        this.fbf(this.rotationArray[rotationBase + 2]),
        this.fbf(this.rotationArray[rotationBase + 3]),
        this.fbf(this.rotationArray[rotationBase])
      );
      rotationRight.set(
        this.fbf(this.rotationRightArray[rotationBase + 1]),
        this.fbf(this.rotationRightArray[rotationBase + 2]),
        this.fbf(this.rotationRightArray[rotationBase + 3]),
        this.fbf(this.rotationRightArray[rotationBase])
      );
      // Matrix4.set takes row-major arguments: a diagonal scale matrix.
      scale4D.set(
        this.fbf(this.scaleArray[scaleBase]), 0, 0, 0,
        0, this.fbf(this.scaleArray[scaleBase + 1]), 0, 0,
        0, 0, this.fbf(this.scaleArray[scaleBase + 2]), 0,
        0, 0, 0, this.fbf(this.scaleTArray[index])
      );

      this.leftQuatToMatrix(rotationLeft, rotationLeftMatrix);
      this.rightQuatToMatrix(rotationRight, rotationRightMatrix);
      rotation4DMatrix.multiplyMatrices(rotationRightMatrix, rotationLeftMatrix);
      factor.multiplyMatrices(scale4D, rotation4DMatrix);
      factorTranspose.copy(factor).transpose();
      return outCovariance.multiplyMatrices(factorTranspose, factor);
    };
  })();

  /**
   * Writes the time column (Σ_xt, Σ_yt, Σ_zt, Σ_tt) of a 4D covariance into
   * `outArray` starting at `base`. A Σ_tt of exactly 0 is written as 1,
   * presumably so consumers can divide by it.
   */
  write4DCovarianceRow(covariance4D, outArray, base) {
    const e = covariance4D.elements;
    outArray[base] = e[12];
    outArray[base + 1] = e[13];
    outArray[base + 2] = e[14];
    outArray[base + 3] = e[15] === 0 ? 1 : e[15];
  }

  /**
   * Reduces a 4D covariance to the 3D covariance of (x, y, z) conditioned on
   * t: Σ_xyz − Σ_xyz,t · Σ_t,xyz / Σ_tt, written into `outCovariance3D`
   * (a THREE.Matrix3) and returned. When Σ_tt is 0 the spatial block is used
   * unchanged instead of dividing by zero.
   */
  condition4DCovariance(covariance4D, outCovariance3D) {
    const e = covariance4D.elements;
    const sigmaTT = e[15];
    outCovariance3D.setFromMatrix4(covariance4D);
    if (sigmaTT === 0) {
      return outCovariance3D;
    }
    const out = outCovariance3D.elements;
    // Matrix elements are column-major: e[12 + col] is Σ_{col,t} and
    // e[4 * row + 3] is Σ_{t,row}.
    for (let col = 0; col < 3; col++) {
      for (let row = 0; row < 3; row++) {
        out[col * 3 + row] -= (e[12 + col] * e[4 * row + 3]) / sigmaTT;
      }
    }
    return outCovariance3D;
  }

  /**
   * For 4D buffers, copies mu_t into `mu_tArray` (if given) and writes the
   * time column of each splat's 4D covariance into `covarRow`, both at splat
   * position `i + destOffset`. Logs a warning and does nothing for 3D buffers.
   */
  fill4DData(covarRow, mu_tArray, destOffset) {
    if (!this.is4D) {
      console.warn("fill4DData called without is4D");
      return;
    }
    if (!mu_tArray) {
      console.log("missing mu_tarray");
    }

    const covariance4D = new THREE.Matrix4();
    for (let i = 0; i < this.splatCount; i++) {
      const dest = i + destOffset;
      if (mu_tArray) {
        // NOTE(review): mu_t is copied without fbf decoding (unlike scale_t);
        // confirm consumers expect raw values when compressionLevel > 0.
        mu_tArray[dest * SplatBuffer.MuTComponentCount] =
          this.muTArray[i * SplatBuffer.MuTComponentCount];
      }
      this.compute4DCovariance(i, covariance4D);
      this.write4DCovarianceRow(
        covariance4D,
        covarRow,
        dest * SplatBuffer.Covariance4DSizeFloats
      );
    }
  }

  /**
   * Decodes the untransformed center of splat `index` into `outCenter`
   * (any object with x/y/z properties) and returns it. Compressed centers are
   * de-quantized relative to their bucket's center.
   */
  readSplatCenter(index, outCenter) {
    const base = index * SplatBuffer.CenterComponentCount;
    const raw = this.centerArray;
    if (this.compressionLevel > 0) {
      const sf = this.compressionScaleFactor;
      const sr = this.compressionScaleRange;
      const bucketIndex = Math.floor(index / this.bucketSize);
      const bucket = new Float32Array(
        this.splatBufferData,
        this.bucketsBase + bucketIndex * this.bytesPerBucket,
        3
      );
      outCenter.x = (raw[base] - sr) * sf + bucket[0];
      outCenter.y = (raw[base + 1] - sr) * sf + bucket[1];
      outCenter.z = (raw[base + 2] - sr) * sf + bucket[2];
    } else {
      outCenter.x = raw[base];
      outCenter.y = raw[base + 1];
      outCenter.z = raw[base + 2];
    }
    return outCenter;
  }

  //TODO(@dlazares): update this for 4d time based
  /**
   * Writes the center of splat `index` into `outCenter`, applying the
   * optional Matrix4 `transform`.
   */
  getSplatCenter(index, outCenter, transform) {
    this.readSplatCenter(index, outCenter);
    if (transform) outCenter.applyMatrix4(transform);
  }

  /**
   * Reads the scale and rotation of splat `index` into `outScale` (Vector3)
   * and `outRotation` (Quaternion). When a Matrix4 `transform` is given, the
   * splat frame R·S is mapped through it and decomposed again, so the outputs
   * describe transform·R·S — the same composition fillSplatCovarianceArray
   * applies to covariances.
   */
  getSplatScaleAndRotation = (function () {
    const scaleMatrix = new THREE.Matrix4();
    const rotationMatrix = new THREE.Matrix4();
    const tempMatrix = new THREE.Matrix4();
    const tempPosition = new THREE.Vector3();

    return function (index, outScale, outRotation, transform) {
      const scaleBase = index * SplatBuffer.ScaleComponentCount;
      outScale.set(
        this.fbf(this.scaleArray[scaleBase]),
        this.fbf(this.scaleArray[scaleBase + 1]),
        this.fbf(this.scaleArray[scaleBase + 2])
      );
      const rotationBase = index * SplatBuffer.RotationComponentCount;
      outRotation.set(
        this.fbf(this.rotationArray[rotationBase + 1]),
        this.fbf(this.rotationArray[rotationBase + 2]),
        this.fbf(this.rotationArray[rotationBase + 3]),
        this.fbf(this.rotationArray[rotationBase])
      );
      if (transform) {
        scaleMatrix.makeScale(outScale.x, outScale.y, outScale.z);
        rotationMatrix.makeRotationFromQuaternion(outRotation);
        // R·S (not S·R): decompose() inverts a T·R·S composition, so this
        // order lets an identity transform return the stored values.
        tempMatrix
          .copy(rotationMatrix)
          .multiply(scaleMatrix)
          .premultiply(transform);
        tempMatrix.decompose(tempPosition, outRotation, outScale);
      }
    };
  })();

  /** Writes the four stored color bytes of splat `index` into `outColor`. */
  getSplatColor(index, outColor, transform) {
    const colorBase = index * SplatBuffer.ColorComponentCount;
    outColor.set(
      this.colorArray[colorBase],
      this.colorArray[colorBase + 1],
      this.colorArray[colorBase + 2],
      this.colorArray[colorBase + 3]
    );
    // TODO: apply transform for spherical harmonics
  }

  /**
   * Writes every splat center into `outCenterArray` as x, y, z triples
   * starting at splat position `destOffset`. `transform` is applied only to
   * 3D buffers.
   */
  fillSplatCenterArray(outCenterArray, destOffset, transform) {
    //TODO(@dlazares): compression for mu_t and 4d?
    const center = new THREE.Vector3();
    const applyTransform = Boolean(transform) && !this.is4D;
    for (let i = 0; i < this.splatCount; i++) {
      this.readSplatCenter(i, center);
      if (applyTransform) {
        center.applyMatrix4(transform);
      }
      const destBase = (i + destOffset) * SplatBuffer.CenterComponentCount;
      outCenterArray[destBase] = center.x;
      outCenterArray[destBase + 1] = center.y;
      outCenterArray[destBase + 2] = center.z;
    }
  }

  /**
   * Builds the matrix of a left-isoclinic 4D rotation from a quaternion, read
   * as (a, b, c, d) = (w, x, y, z) since three.js stores xyzw while the paper
   * uses wxyz. Each group of four arguments below becomes one column of the
   * result (they are passed to the row-major `set`, then transposed).
   * https://en.wikipedia.org/wiki/Rotations_in_4-dimensional_Euclidean_space#Isoclinic_decomposition
   *
   * @param {THREE.Quaternion} quat
   * @param {THREE.Matrix4} [target] - matrix to overwrite; created if omitted.
   * @returns {THREE.Matrix4} `target`.
   */
  leftQuatToMatrix(quat, target = new THREE.Matrix4()) {
    const a = quat.w;
    const b = quat.x;
    const c = quat.y;
    const d = quat.z;

    target.set(
      a, -b, -c, -d, // Column 1
      b, a, -d, c, // Column 2
      c, d, a, -b, // Column 3
      d, -c, b, a // Column 4
    );
    target.transpose();

    return target;
  }

  /**
   * Builds the matrix of a right-isoclinic 4D rotation from a quaternion, read
   * as (p, q, r, s) = (w, x, y, z). Each group of four arguments below becomes
   * one column of the result (row-major `set`, then transposed).
   * https://en.wikipedia.org/wiki/Rotations_in_4-dimensional_Euclidean_space#Isoclinic_decomposition
   *
   * @param {THREE.Quaternion} quat
   * @param {THREE.Matrix4} [target] - matrix to overwrite; created if omitted.
   * @returns {THREE.Matrix4} `target`.
   */
  rightQuatToMatrix(quat, target = new THREE.Matrix4()) {
    const p = quat.w;
    const q = quat.x;
    const r = quat.y;
    const s = quat.z;

    target.set(
      p, q, r, s, // Column 1
      -q, p, -s, r, // Column 2
      -r, s, p, -q, // Column 3
      -s, -r, q, p // Column 4
    );
    target.transpose();

    return target;
  }

  /**
   * Writes each splat's 3D covariance (upper triangle: xx, xy, xz, yy, yz, zz)
   * into `covarianceArray` at splat position `i + destOffset`, mapped through
   * the optional Matrix4 `transform` as T·Σ·Tᵀ (upper-left 3x3 of T).
   *
   * For 4D buffers the 3D covariance is the 4D covariance conditioned on t;
   * additionally the 4D time column is written to `covariances4DArray` and
   * mu_t to `mu_tArray` when those arrays are provided.
   */
  fillSplatCovarianceArray(
    covarianceArray,
    destOffset,
    transform,
    covariances4DArray,
    mu_tArray = undefined
  ) {
    const scale = new THREE.Vector3();
    const rotation = new THREE.Quaternion();
    const rotationMatrix = new THREE.Matrix3();
    const scaleMatrix = new THREE.Matrix3();
    const covarianceMatrix = new THREE.Matrix3();
    const transformedCovariance = new THREE.Matrix3();
    const transform3x3 = new THREE.Matrix3();
    const transform3x3Transpose = new THREE.Matrix3();
    const tempMatrix4 = new THREE.Matrix4();
    const covariance4D = new THREE.Matrix4();

    // The transform is loop-invariant: extract it once.
    if (transform) {
      transform3x3.setFromMatrix4(transform);
      transform3x3Transpose.copy(transform3x3).transpose();
    }

    for (let i = 0; i < this.splatCount; i++) {
      const dest = i + destOffset;

      if (this.is4D) {
        if (mu_tArray) {
          mu_tArray[dest * SplatBuffer.MuTComponentCount] =
            this.muTArray[i * SplatBuffer.MuTComponentCount];
        }
        this.compute4DCovariance(i, covariance4D);
        if (covariances4DArray) {
          this.write4DCovarianceRow(
            covariance4D,
            covariances4DArray,
            SplatBuffer.Covariance4DSizeFloats * dest
          );
        }
        this.condition4DCovariance(covariance4D, transformedCovariance);
      } else {
        const rotationBase = i * SplatBuffer.RotationComponentCount;
        rotation.set(
          this.fbf(this.rotationArray[rotationBase + 1]),
          this.fbf(this.rotationArray[rotationBase + 2]),
          this.fbf(this.rotationArray[rotationBase + 3]),
          this.fbf(this.rotationArray[rotationBase])
        );
        const scaleBase = i * SplatBuffer.ScaleComponentCount;
        scale.set(
          this.fbf(this.scaleArray[scaleBase]),
          this.fbf(this.scaleArray[scaleBase + 1]),
          this.fbf(this.scaleArray[scaleBase + 2])
        );
        tempMatrix4.makeScale(scale.x, scale.y, scale.z);
        scaleMatrix.setFromMatrix4(tempMatrix4);

        tempMatrix4.makeRotationFromQuaternion(rotation);
        rotationMatrix.setFromMatrix4(tempMatrix4);
        // Σ = (R·S)(R·S)ᵀ
        covarianceMatrix.copy(rotationMatrix).multiply(scaleMatrix);
        transformedCovariance
          .copy(covarianceMatrix)
          .transpose()
          .premultiply(covarianceMatrix);
      }

      if (transform) {
        transformedCovariance.multiply(transform3x3Transpose);
        transformedCovariance.premultiply(transform3x3);
      }

      // Elements are column-major; the symmetric upper triangle is stored.
      const covBase = SplatBuffer.CovarianceSizeFloats * dest;
      const e = transformedCovariance.elements;
      covarianceArray[covBase] = e[0];
      covarianceArray[covBase + 1] = e[3];
      covarianceArray[covBase + 2] = e[6];
      covarianceArray[covBase + 3] = e[4];
      covarianceArray[covBase + 4] = e[7];
      covarianceArray[covBase + 5] = e[8];
    }
  }

  /**
   * Copies every splat's four color bytes into `outColorArray` starting at
   * splat position `destOffset`.
   */
  fillSplatColorArray(outColorArray, destOffset, transform) {
    const splatCount = this.splatCount;
    for (let i = 0; i < splatCount; i++) {
      const colorSrcBase = i * SplatBuffer.ColorComponentCount;
      const colorDestBase = (i + destOffset) * SplatBuffer.ColorComponentCount;
      outColorArray[colorDestBase] = this.colorArray[colorSrcBase];
      outColorArray[colorDestBase + 1] = this.colorArray[colorSrcBase + 1];
      outColorArray[colorDestBase + 2] = this.colorArray[colorSrcBase + 2];
      outColorArray[colorDestBase + 3] = this.colorArray[colorSrcBase + 3];
      // TODO: implement application of transform for spherical harmonics
    }
  }
}
