
/**
 * Computes the component-wise arithmetic mean of a list of points.
 *
 * The number of components is taken from the first point; every point is
 * read at those same component indices.
 *
 * @param {number[][]} points - Points to average.
 * @returns {number[]|undefined} The mean point, or `undefined` when
 *   `points` is empty.
 */
function centroid(points) {
  const count = points.length;
  if (count === 0) return undefined;

  const dims = points[0].length;
  const totals = Array.from({length: dims}, () => 0);

  // Accumulate point by point so each component is summed in input order.
  for (const point of points) {
    for (let axis = 0; axis < dims; axis += 1) {
      totals[axis] += point[axis];
    }
  }

  return totals.map(total => total / count);
}

/**
 * Computes the centroid of the points assigned to one cluster.
 *
 * @param {number[][]} points - All points.
 * @param {Array} clusters - Cluster label per point, parallel to `points`.
 * @param {*} idx - Label of the cluster of interest (compared with `===`).
 * @returns {number[]|undefined} Whatever `centroid` returns for the
 *   selected points.
 */
function clusterCentroid(points, clusters, idx) {
  const members = points.filter((_, position) => clusters[position] === idx);
  return centroid(members);
}


/**
 * Computes the squared Euclidean distance between two points.
 *
 * Iterates over the components of `p1`; `p2` is read at the same indices.
 *
 * @param {number[]} p1 - First point.
 * @param {number[]} p2 - Second point.
 * @returns {number} Sum of squared component differences.
 */
function squaredDistance(p1, p2) {
  const dims = p1.length;
  let total = 0;
  for (let axis = 0; axis < dims; axis += 1) {
    const delta = p1[axis] - p2[axis];
    total += delta * delta;
  }
  return total;
}

/**
 * Picks a random index into `points`.
 *
 * When `center` is given, each index is weighted by the squared distance of
 * its point to `center`; when `center` is falsy every index has weight 1
 * (uniform choice).
 *
 * @param {number[][]} points - Candidate points.
 * @param {number[]} [center] - Reference point for the distance weights.
 * @returns {number} The chosen index. Returns 0 when `points` is empty or
 *   the total weight is not a positive finite number.
 */
function weightedRandomChoice(points, center) {
  const count = points.length;
  const weights = Array.from({length: count});
  let total = 0;
  for (let j = 0; j < count; j++) {
    weights[j] = center ? squaredDistance(points[j], center) : 1.0;
    total += weights[j];
  }

  // Degenerate weights (no points, every point on the center, or NaN data):
  // there is no meaningful distribution to sample, so pick the first index
  // explicitly instead of dividing by zero.
  if (!Number.isFinite(total) || total <= 0) return 0;

  // Scale the draw up to the total weight rather than normalising every
  // weight down: the running sum then ends on exactly `total`, so rounding
  // can no longer leave a gap at the top of the range.
  const target = Math.random() * total;
  let running = 0;
  for (let j = 0; j < count; j++) {
    running += weights[j];
    if (target < running) return j;
  }
  // Defensive fallback: a draw at the very top of the range belongs to the
  // last index, not the first.
  return count - 1;
}

/**
 * Assigns every unlabeled point to its nearest centroid.
 *
 * Works on a copy of `clusters`; the input array is not modified. Entries
 * that are already greater than -1 are kept as they are. On equal distances
 * the centroid with the lowest index wins.
 *
 * @param {Object} args
 * @param {number[][]} args.points - Points to label.
 * @param {number[][]} args.centroids - Current centroids.
 * @param {number[]} args.clusters - Existing labels, parallel to `points`.
 * @returns {number[]} New label array.
 */
function assignClusters({points, centroids, clusters}) {
  const labels = clusters.slice();
  for (let p = 0; p < points.length; p += 1) {
    const isPreset = labels[p] > -1;
    if (!isPreset) {
      let nearest = Infinity;
      for (let c = 0; c < centroids.length; c += 1) {
        const candidate = squaredDistance(points[p], centroids[c]);
        if (candidate < nearest) {
          nearest = candidate;
          labels[p] = c;
        }
      }
    }
  }
  return labels;
}

/**
 * Shallow element-wise equality of two arrays.
 *
 * Elements are compared with `!==`, so `NaN` never equals `NaN`.
 *
 * @param {Array} arr1
 * @param {Array} arr2
 * @returns {boolean} `true` when the lengths match and every pair of
 *   elements is strictly equal.
 */
function arraysEqual(arr1, arr2) {
  if (arr1.length !== arr2.length) return false;

  let index = 0;
  while (index < arr1.length) {
    if (arr1[index] !== arr2[index]) return false;
    index += 1;
  }
  return true;
}

/**
 * Total squared distance from each point to the centroid of its cluster.
 *
 * @param {Object} args
 * @param {number[][]} args.points - Clustered points.
 * @param {number[][]} args.centroids - Centroids, indexed by cluster label.
 * @param {number[]} args.clusters - Cluster label per point.
 * @returns {number} Sum of squared point-to-centroid distances.
 */
function score({points, centroids, clusters}) {
  let total = 0;
  let position = 0;
  for (const point of points) {
    const assigned = centroids[clusters[position]];
    total += squaredDistance(point, assigned);
    position += 1;
  }
  return total;
}

/**
 * Finds the index of the element for which `f` yields the smallest value.
 *
 * The first minimum wins on ties. Returns 0 when the array is empty or no
 * value compares below `Infinity`.
 *
 * @param {Array} array - Elements to inspect.
 * @param {function(*): number} f - Maps an element to the value to minimise.
 * @returns {number} Index of the minimising element.
 */
function indmin(array, f) {
  let bestIndex = 0;
  let bestValue = Infinity;
  let position = 0;
  for (const item of array) {
    const value = f(item);
    if (value < bestValue) {
      bestValue = value;
      bestIndex = position;
    }
    position += 1;
  }
  return bestIndex;
}

/**
 * Runs k-means clustering, optionally honouring preset cluster labels.
 *
 * Labels in `clusters` that are greater than -1 are treated as preset: they
 * seed the centroid of that cluster, and `assignClusters` leaves them
 * untouched. Points labelled -1 are assigned to the nearest centroid on
 * every iteration. Iteration stops when no centroid changes or after
 * `maxiter` iterations.
 *
 * @param {Object} args
 * @param {number[][]} args.points - Points to cluster.
 * @param {number[]} [args.clusters] - Initial label per point; -1 means
 *   unassigned. Defaults to all -1. Preset labels are presumably integers
 *   in `[0, k)` — not validated here.
 * @param {number} args.k - Number of clusters.
 * @param {number} [args.maxiter=Infinity] - Maximum number of iterations.
 * @param {number} [args.numTries=1] - Number of independent runs; the run
 *   with the lowest score is returned.
 * @returns {{clusters: number[], centroids: number[][], score: number, iter: number}}
 *   Final labels, centroids, total squared distance and iteration count.
 *   For empty `points` the result is empty with `score` 0 and `iter` 0.
 */
export function kmeans({
  points,
  clusters = points.map(() => -1),
  k,
  maxiter = Infinity,
  numTries = 1,
}) {
  // Nothing to cluster: return an empty result instead of failing while
  // comparing centroids that could never be seeded.
  if (points.length === 0) {
    return {clusters: [], centroids: [], score: 0, iter: 0};
  }

  if (numTries > 1) {
    const tries = [];
    for (let t = 0; t < numTries; t++) {
      tries.push(kmeans({points, clusters, k, maxiter, numTries: 1}));
    }
    return tries[indmin(tries, result => result.score)];
  }

  const centroids = Array.from({length: k});

  // Seed the centroids of preset clusters from their members. Each label
  // indexes directly into `centroids`, so no ordering is needed. The `> -1`
  // test mirrors how assignClusters recognises a preset label.
  for (const label of new Set(clusters)) {
    if (label > -1) {
      centroids[label] = clusterCentroid(points, clusters, label);
    }
  }

  // kmeans++-style initialization for the centroids not covered by presets.
  // NOTE(review): the weights are squared distances to the mean of the
  // centroids chosen so far, whereas textbook k-means++ weights by distance
  // to the nearest chosen centroid. Left unchanged; confirm intent.
  for (let j = 0; j < centroids.length; j++) {
    if (!centroids[j]) {
      const idx = weightedRandomChoice(points, centroid(centroids.filter(Boolean)));
      centroids[j] = points[idx];
    }
  }

  let iter = 0;
  let newClusters;
  while (true) {
    // Always start from the original labels so presets stay fixed.
    newClusters = assignClusters({points, centroids, clusters});
    let shouldBreak = true;
    for (let j = 0; j < k; j++) {
      const oldCentroid = centroids[j];
      // An empty cluster has no centroid; keep the previous one.
      centroids[j] = clusterCentroid(points, newClusters, j) || oldCentroid;
      if (!arraysEqual(oldCentroid, centroids[j])) shouldBreak = false;
    }
    iter++;
    if (shouldBreak || iter >= maxiter) break;
  }

  return {
    clusters: newClusters,
    centroids,
    score: score({points, centroids, clusters: newClusters}),
    iter,
  };
}