public class Solution {
public String solve(String A, int B) {
if (B == 0) {
return A;
}
char[] ch = A.toCharArray();
TreeMap<Character, TreeSet<Integer>> map = new TreeMap<>((c1, c2) -> c2 - c1);
for (int i = 0; i < ch.length; i++) {
TreeSet<Integer> index = map.get(ch[i]);
if (index == null) {
index = new TreeSet<>((index1, index2) -> index2 - index1);
map.put(ch[i], index);
}
index.add(i);
}
TreeSet<Integer> pq = new TreeSet<>((index1, index2) -> ch[index1] - ch[index2]);
int i = 0;
while (i < ch.length && B > 0) {
Map.Entry<Character, TreeSet<Integer>> maxEntry = map.pollFirstEntry();
// System.out.println("Before: " + pq + " " + maxEntry.getKey() + "=" + maxEntry.getValue());
int j = 0;
while (B > 0 && j < maxEntry.getValue().size()) {
if (ch[i] != maxEntry.getKey()) {
pq.add(i);
j++;
B--;
}
else {
maxEntry.getValue().remove(i);
}
i++;
}
// System.out.println("After: " + pq + " " + maxEntry.getKey() + "=" + maxEntry.getValue());
while (!pq.isEmpty()) {
int index = pq.pollFirst();
int swapIndex = maxEntry.getValue().pollFirst();
char t = ch[index];
ch[index] = ch[swapIndex];
ch[swapIndex] = t;
// System.out.println("Swapped: " + new String(ch));
map.get(t).remove(index);
map.get(t).add(swapIndex);
}
}
return new String(ch);
}
}