import * as tf from '@tensorflow/tfjs'

// Class labels, in the same order as each model's prediction vector
// (predict* maps label i to score i). Frozen because they are shared constants.
const license_wheel_labels = Object.freeze(['license', 'other', 'wheel']);
const car_6_parts_labels = Object.freeze(['front', 'frontLeft', 'frontRight', 'other', 'rear', 'rearLeft', 'rearRight']);

// Lazily loaded models: assigned by loadBothModels / loadObjectFittingModel
// and read by the predict* functions. Undefined until loaded.
let modelLicenseWheel;
let modelSixParts;

/**
 * Default crop box per car view, expressed as fractions of the canvas size:
 * cropObjectFitting multiplies `width`/`coordX` by the canvas width and
 * `height`/`coordY` by the canvas height. The model loaders replace or patch
 * entries in place, so this object is deliberately left mutable.
 */
const objectDetection = (() => {
  // Builds one box, keeping the key order { width, height, coordX, coordY }.
  const box = (width, height, coordX, coordY) => ({ width, height, coordX, coordY });

  return {
    carType: 'turismo',
    front: box(0.2273, 0.1, 0.3539, 0.7069),
    rear: box(0.2211, 0.1, 0.3414, 0.5692),
    frontLeft: box(0.1891, 0.3361, 0.2742, 0.5875),
    frontRight: box(0.1891, 0.3361, 0.4523, 0.5528),
    rearLeft: box(0.1891, 0.3361, 0.4555, 0.6097),
    rearRight: box(0.1891, 0.3361, 0.3227, 0.5708),
  };
})();

/**
 * Loads whichever of the license/wheel model and the six-car-parts model is
 * not loaded yet, warms the new ones up with a dummy prediction, and stores
 * them in `modelLicenseWheel` / `modelSixParts`.
 *
 * When the license/wheel model is not loaded yet, the crop boxes in
 * `objectDetection` are adjusted first: `plateType === 'big'` swaps in the
 * big-plate `front`/`rear` boxes, then each entry of `parts` overrides the
 * box named by `part.zone`. This matches loadObjectFittingModel.
 *
 * @param {string} plateType - 'big' selects the big-plate front/rear boxes.
 * @param {Array<{zone: string, area: {width: number, height: number},
 *   coordinates: {x: number, y: number}}>} [parts=[]] - per-zone box overrides;
 *   null/undefined fields keep the current value, and unknown zones are ignored.
 * @returns {Promise<false|undefined>} `false` after loading, `undefined` when
 *   both models were already loaded.
 * @throws Rejects with the underlying error if loading or warm-up fails; the
 *   module-level models are left unchanged in that case.
 */
export async function loadBothModels(plateType, parts = []) {
  const needWheel = !modelLicenseWheel;
  const needSix = !modelSixParts;
  if (!needWheel && !needSix) {
    return undefined;
  }

  // Crop boxes belong to the license/wheel (object fitting) model, so only
  // adjust them the first time that model is loaded.
  if (needWheel) {
    if (plateType === 'big') {
      objectDetection.front = { width: 0.159375, height: 0.1430556, coordX: 0.3875, coordY: 0.7055556 };
      objectDetection.rear = { width: 0.1523438, height: 0.1388889, coordX: 0.3765625, coordY: 0.5597222 };
    }

    for (const part of parts) {
      // Only own, object-valued entries are boxes; this skips unknown zones and
      // scalar entries such as 'carType' (patching a string throws in strict mode).
      const box = Object.prototype.hasOwnProperty.call(objectDetection, part.zone)
        ? objectDetection[part.zone]
        : undefined;
      if (!box || typeof box !== 'object') {
        continue;
      }
      const { area: { width, height }, coordinates: { x: coordX, y: coordY } } = part;
      const updates = { width, height, coordX, coordY };
      for (const key of Object.keys(box)) {
        // `!= null` instead of truthiness so a legitimate 0 is applied.
        if (updates[key] != null) {
          box[key] = updates[key];
        }
      }
    }
  }

  let wheelModel;
  let sixModel;
  try {
    // Already-loaded models pass through Promise.all as plain values.
    [wheelModel, sixModel] = await Promise.all([
      needWheel ? tf.loadLayersModel('../assets/models/license_wheel/model.json') : modelLicenseWheel,
      needSix ? tf.loadLayersModel('../assets/models/car_6_parts/model.json') : modelSixParts,
    ]);
  } catch (err) {
    console.log("Error loading models", err);
    throw err;
  }

  try {
    // Warm up only the freshly loaded models; tidy frees the dummy tensors.
    tf.tidy(() => {
      if (needWheel) {
        wheelModel.predict(tf.zeros([1, 128, 128, 3]));
      }
      if (needSix) {
        sixModel.predict(tf.zeros([1, 224, 224, 3]));
      }
    });
  } catch (err) {
    console.log("Error inicializing models", err);
    throw err;
  }

  modelLicenseWheel = wheelModel;
  modelSixParts = sixModel;
  return false;
}

/**
 * Loads the license/wheel (object fitting) model if it is not loaded yet,
 * warms it up with a dummy prediction, and stores it in `modelLicenseWheel`.
 *
 * Before loading, the crop boxes in `objectDetection` are adjusted:
 * `plateType === 'big'` swaps in the big-plate `front`/`rear` boxes, then each
 * entry of `parts` overrides the box named by `part.zone`.
 *
 * @param {string} plateType - 'big' selects the big-plate front/rear boxes.
 * @param {Array<{zone: string, area: {width: number, height: number},
 *   coordinates: {x: number, y: number}}>} [parts=[]] - per-zone box overrides;
 *   null/undefined fields keep the current value, and unknown zones are ignored.
 * @returns {Promise<false|undefined>} `false` after loading, `undefined` when
 *   the model was already loaded.
 * @throws Rejects with the underlying error if loading or warm-up fails;
 *   `modelLicenseWheel` is left unchanged in that case.
 */
export async function loadObjectFittingModel(plateType, parts = []) {
  if (modelLicenseWheel) {
    return undefined;
  }

  const modelWheel = '../assets/models/license_wheel/model.json';
  if (plateType === 'big') {
    objectDetection.front = { width: 0.159375, height: 0.1430556, coordX: 0.3875, coordY: 0.7055556 };
    objectDetection.rear = { width: 0.1523438, height: 0.1388889, coordX: 0.3765625, coordY: 0.5597222 };
  }

  for (const part of parts) {
    // Only own, object-valued entries are boxes; this skips unknown zones and
    // scalar entries such as 'carType' (patching a string throws in strict mode).
    const box = Object.prototype.hasOwnProperty.call(objectDetection, part.zone)
      ? objectDetection[part.zone]
      : undefined;
    if (!box || typeof box !== 'object') {
      continue;
    }
    const { area: { width, height }, coordinates: { x: coordX, y: coordY } } = part;
    const updates = { width, height, coordX, coordY };
    for (const key of Object.keys(box)) {
      // `!= null` instead of truthiness so a legitimate 0 is applied.
      if (updates[key] != null) {
        box[key] = updates[key];
      }
    }
  }

  let detectionModel;
  try {
    detectionModel = await tf.loadLayersModel(modelWheel);
  } catch (err) {
    console.log("Error loading models", err);
    throw err;
  }

  try {
    // Dummy prediction to warm the model up; tidy frees the dummy tensors.
    tf.tidy(() => {
      detectionModel.predict(tf.zeros([1, 128, 128, 3]));
    });
  } catch (err) {
    console.log("Error inicializing models", err);
    throw err;
  }

  modelLicenseWheel = detectionModel;
  return false;
}

/**
 * Returns the pixels of the crop box configured for `imageType`, scaled to a
 * canvas of the given size (box fields are fractions of width/height).
 *
 * @param {CanvasRenderingContext2D} ctx - context to read pixels from.
 * @param {number} width - canvas width in pixels.
 * @param {number} height - canvas height in pixels.
 * @param {string} imageType - key of a box in `objectDetection` (e.g. 'front').
 * @returns {ImageData} result of `ctx.getImageData` for the scaled box.
 * @throws {Error} if `imageType` does not name a crop box.
 */
export function cropObjectFitting(ctx, width, height, imageType) {
  // Reject unknown keys and non-box entries such as 'carType', which would
  // otherwise produce NaN coordinates or an opaque TypeError.
  const box = Object.prototype.hasOwnProperty.call(objectDetection, imageType)
    ? objectDetection[imageType]
    : undefined;
  if (!box || typeof box !== 'object') {
    throw new Error(`Unknown image type: ${imageType}`);
  }
  const iX = box.coordX * width;
  const iY = box.coordY * height;
  const croppedwidth = box.width * width;
  const croppedheight = box.height * height;
  return ctx.getImageData(iX, iY, croppedwidth, croppedheight);
}

/**
 * Classifies which car part/view `img` shows, using the six-parts model.
 *
 * The image is resized to 224x224 and scaled from [0, 255] to [-1, 1] before
 * prediction; all intermediate tensors are freed by `tf.tidy`.
 *
 * @param {*} img - pixel source accepted by `tf.browser.fromPixels`.
 * @returns {Promise<Array<{label: string, accuracy: number}>>} one entry per
 *   label in `car_6_parts_labels`, sorted by descending score.
 * @throws {Error} rejects if the six-parts model has not been loaded yet.
 */
export async function predictCarPartOcclusion(img) {
  if (!modelSixParts) {
    throw new Error('Car parts model is not loaded; call loadBothModels() first');
  }
  return tf.tidy(() => {
    const image = tf.browser.fromPixels(img).toFloat();
    const resized = tf.image.resizeBilinear(image, [224, 224]);
    // (x - 127.5) / 127.5 maps pixel values 0..255 onto -1..1.
    const offset = tf.scalar(255 / 2);
    const normalized = resized.sub(offset).div(offset);
    const input = normalized.expandDims(0);
    const scores = Array.from(modelSixParts.predict(input).dataSync());
    return car_6_parts_labels
      .map((label, index) => ({ label, accuracy: scores[index] }))
      .sort((a, b) => b.accuracy - a.accuracy);
  });
}

/**
 * Classifies a crop as license plate, wheel or other, using the
 * license/wheel model.
 *
 * The image is resized to 128x128 and scaled from [0, 255] to [-1, 1] before
 * prediction; all intermediate tensors are freed by `tf.tidy`.
 *
 * @param {*} img - pixel source accepted by `tf.browser.fromPixels` (e.g. the
 *   ImageData returned by cropObjectFitting — TODO confirm against callers).
 * @returns {Promise<Array<{label: string, accuracy: number}>>} one entry per
 *   label in `license_wheel_labels`, sorted by descending score.
 * @throws {Error} rejects if the license/wheel model has not been loaded yet.
 */
export async function predictObjectFitting(img) {
  if (!modelLicenseWheel) {
    throw new Error('License/wheel model is not loaded; call loadObjectFittingModel() or loadBothModels() first');
  }
  return tf.tidy(() => {
    const image = tf.browser.fromPixels(img).toFloat();
    const resized = tf.image.resizeBilinear(image, [128, 128]);
    // (x - 127.5) / 127.5 maps pixel values 0..255 onto -1..1.
    const offset = tf.scalar(255 / 2);
    const normalized = resized.sub(offset).div(offset);
    const input = normalized.expandDims(0);
    const scores = Array.from(modelLicenseWheel.predict(input).dataSync());
    return license_wheel_labels
      .map((label, index) => ({ label, accuracy: scores[index] }))
      .sort((a, b) => b.accuracy - a.accuracy);
  });
}