tdlight/UpdateMemoryManager.javash
2024-03-20 23:26:07 +01:00

379 lines
12 KiB
Plaintext
Executable File

#!/usr/bin/java --source 21
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.FileVisitOption;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Objects;
import java.util.SequencedSet;
import java.util.Set;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
public class UpdateMemoryManager {
/**
 * Manager names (without the {@code Manager} suffix) that are skipped:
 * {@code main} drops every discovered manager whose name is contained in this set.
 */
static final Set<String> EXCLUDED_MANAGERS = Set.of(
    "Memory", "Call", "DeviceToken", "LanguagePack", "Pts", "Password",
    "SecretChats", "Secure", "Config", "Storage", "FileLoad", "Parts",
    "FileGenerate", "Resource", "NetStats", "DcAuth", "State", "PhoneNumber");
/**
 * A manager class identified by the directory holding its sources and its name
 * without the {@code Manager} suffix (the header is {@code <name>Manager.h}).
 *
 * @param directory directory that contains {@code <name>Manager.h}
 * @param name manager name without the {@code Manager} suffix
 */
record Manager(Path directory, String name) {

  /**
   * Returns the path to use in an {@code #include} directive for this manager's header.
   *
   * <p>The path starts at the last {@code td/telegram} component pair of the (normalized)
   * directory, so the result does not depend on how the root was spelled on the command
   * line ({@code .}, a relative directory or an absolute path). When the directory has no
   * {@code td/telegram} component pair, the whole normalized directory is used. Components
   * are always joined with {@code /}, regardless of the platform separator.
   *
   * @return for example {@code td/telegram/files/FileManager.h}
   */
  String includePath() {
    Path normalized = directory.normalize();
    int count = normalized.getNameCount();
    int start = 0;
    // Search from the end so that a "td/telegram" pair in the root's own location is ignored.
    for (int i = count - 2; i >= 0; i--) {
      if ("td".equals(normalized.getName(i).toString())
          && "telegram".equals(normalized.getName(i + 1).toString())) {
        start = i;
        break;
      }
    }
    StringBuilder include = new StringBuilder();
    for (int i = start; i < count; i++) {
      String element = normalized.getName(i).toString();
      // The empty path has a single empty name element; skip it.
      if (!element.isEmpty()) {
        include.append(element).append('/');
      }
    }
    return include + name + "Manager.h";
  }
}
/**
 * Entry point. Expects the source root as first argument, discovers every
 * {@code *Manager.h} file below {@code <root>/td/telegram} (up to 16 levels deep), skips the
 * names listed in {@code EXCLUDED_MANAGERS}, regenerates {@code memory_stats} for each
 * remaining manager and finally regenerates the aggregating method of the Memory manager.
 *
 * <p>Managers whose header or source file is missing are reported as invalid and are not
 * referenced from the aggregating method, because no {@code memory_stats} was generated
 * for them.
 *
 * @param args command line arguments; {@code args[0]} is the source root
 * @throws Exception if walking the tree or reading/writing a source file fails
 */
public static void main(String[] args) throws Exception {
  if (args.length == 0) {
    System.err.println("Arguments: PATH");
    System.exit(1);
  }
  var path = Path.of(args[0]);
  var telegramPath = path.resolve("td/telegram");
  var memoryManager = new Manager(telegramPath, "Memory");

  SequencedSet<Manager> managers;
  try (var stream = Files.walk(telegramPath, 16)) {
    var endNamePattern = "Manager.h";
    managers = stream
        .filter(Files::isRegularFile)
        .filter(p -> p.getFileName().toString().endsWith(endNamePattern))
        .map(p -> {
          var fileName = p.getFileName().toString();
          var managerName = fileName.substring(0, fileName.length() - endNamePattern.length());
          return new Manager(p.getParent(), managerName);
        })
        .filter(manager -> !EXCLUDED_MANAGERS.contains(manager.name()))
        .collect(Collectors.toCollection(LinkedHashSet::new));
  }
  System.out.printf("Found %d managers%n", managers.size());

  int fieldsFound = 0;
  int updatedManagers = 0;
  int totalManagers = managers.size();
  List<Manager> invalidManagers = new ArrayList<>();
  // Only managers that were actually processed may be referenced by the Memory manager.
  SequencedSet<Manager> validManagers = new LinkedHashSet<>();
  for (Manager manager : managers) {
    var result = updateManagerJson(manager);
    if (result == null) {
      invalidManagers.add(manager);
      continue;
    }
    validManagers.add(manager);
    fieldsFound += result.fieldsFound();
    if (result.changed()) {
      updatedManagers++;
    }
  }

  updateMemoryManagerJson(memoryManager, validManagers);

  System.out.printf("%n%nDone.%n");
  if (!invalidManagers.isEmpty()) {
    // Every entry already ends with a newline, so the entries are joined without a separator.
    System.out.printf("%d invalid managers found:%n%s",
        invalidManagers.size(),
        invalidManagers.stream()
            .map(x -> "\t\"" + x.directory() + "\": " + x.name() + "\n")
            .collect(Collectors.joining()));
  }
  System.out.printf("%d/%d managers updated, %d total fields%n", updatedManagers, totalManagers, fieldsFound);
}
/**
 * Kinds of container fields that are looked for in manager headers. Each constant carries
 * the C++ type name to search for and the name of the C++ method that is called on the
 * field to obtain its size in the generated code.
 */
enum FieldType {
  WaitFreeHashMap("WaitFreeHashMap", "calc_size"),
  WaitFreeHashSet("WaitFreeHashSet", "calc_size"),
  Vector("vector", "size"),
  FlatHashMap("FlatHashMap", "size"),
  FlatHashSet("FlatHashSet", "size");

  private final String fieldName;
  private final String sizeMethodName;

  /**
   * Matches a whole header line that declares a field of this type at exactly two spaces
   * of indentation, optionally {@code mutable}, optionally with template arguments. The
   * declared identifier is available as the named group {@code field}.
   */
  public final Pattern pattern;

  FieldType(String fieldName, String sizeMethodName) {
    this.fieldName = fieldName;
    this.sizeMethodName = sizeMethodName;
    // The identifier may contain digits after its first character (e.g. "cache2_").
    this.pattern = Pattern.compile("^ {2}(mutable )?" + fieldName
        + "(<([^ ]|, )+>)? +(?<field>[a-zA-Z_][a-zA-Z0-9_]*);?[ \t/]*$");
  }

  /** Returns the declaration pattern of this field type (same object as {@link #pattern}). */
  Pattern getPattern() {
    return pattern;
  }
}
/** A field declaration matched in a manager header: the matching field type and the declared identifier. */
record FoundField(FieldType type, String name) {}
/** Outcome of updating one manager: whether its .cpp content changed and how many fields were found in its header. */
record UpdateResult(boolean changed, int fieldsFound) {}
/**
 * Regenerates the {@code <name>Manager::memory_stats} method of one manager.
 *
 * <p>The header is scanned for container fields (see {@code FieldType}) declared inside the
 * manager's class. A method that emits one {@code "field":size} entry per field is then
 * generated, any previous definition of that method (together with the blank lines directly
 * above it) is removed from the {@code .cpp} file, and the new definition is inserted before
 * the last line of the file. The file is only written when its content changed.
 *
 * @param manager the manager to update
 * @return the update outcome, or {@code null} when the header or the source file is missing
 * @throws IOException if a file cannot be read or written
 * @throws IllegalStateException if the end of the previous method cannot be found or the
 *     source file has no line before which the method could be inserted
 */
private static UpdateResult updateManagerJson(Manager manager) throws IOException {
  Path hFile = manager.directory.resolve(manager.name + "Manager.h");
  Path cppFile = manager.directory.resolve(manager.name + "Manager.cpp");
  System.out.printf("Updating manager \"%s\" files: [\"%s\", \"%s\"]%n", manager.name, hFile, cppFile);
  if (Files.notExists(hFile)) {
    System.out.printf("File not found, ignoring manager \"%s\": \"%s\"%n", manager.name, hFile);
    return null;
  }
  if (Files.notExists(cppFile)) {
    System.out.printf("File not found, ignoring manager \"%s\": \"%s\"%n", manager.name, cppFile);
    return null;
  }

  // Collect the container fields declared inside the manager's own class.
  List<FoundField> fields = new ArrayList<>();
  boolean currentClass = false;
  for (String hLine : normalizeSourceFile(readSourceFile(hFile))) {
    if (hLine.startsWith("class ")) {
      // Every top-level "class" line switches tracking on or off.
      currentClass = hLine.contains(" " + manager.name + "Manager");
    }
    if (!currentClass) {
      continue;
    }
    for (FieldType possibleFieldType : FieldType.values()) {
      var m = possibleFieldType.getPattern().matcher(hLine);
      if (m.matches()) {
        var field = new FoundField(possibleFieldType, m.group("field"));
        System.out.println("\tFound field: (%s) %s".formatted(field.type, field.name));
        fields.add(field);
        break;
      }
    }
  }

  // Generate the new method body.
  StringBuilder memoryStatsMethod = new StringBuilder();
  memoryStatsMethod.append("void %sManager::memory_stats(vector<string> &output) {\n".formatted(manager.name));
  memoryStatsMethod.append(fields.stream()
      .map(field -> " output.emplace_back(\"\\\"%s\\\":\"); output.emplace_back(std::to_string(this->%s.%s()));\n".formatted(field.name, field.name, field.type.sizeMethodName))
      .collect(Collectors.joining(" output.emplace_back(\",\");\n")));
  memoryStatsMethod.append("}\n");
  List<String> memoryStatsMethodLines = Arrays.asList(memoryStatsMethod.toString().split("\n"));

  // Work on a private mutable copy so the edits never depend on the reader's list type.
  var cppLines = new ArrayList<>(readSourceFile(cppFile));
  var inputCppLines = new ArrayList<>(cppLines);

  // Remove the old memory_stats method of this manager (and only of this manager).
  var methodMarker = manager.name + "Manager::memory_stats(";
  var indexOfMemoryStatsStart = -1;
  for (int i = 0; i < cppLines.size(); i++) {
    if (cppLines.get(i).contains(methodMarker)) {
      indexOfMemoryStatsStart = i;
      // Also drop the blank lines directly above, so repeated runs do not accumulate them.
      while (indexOfMemoryStatsStart > 0 && cppLines.get(indexOfMemoryStatsStart - 1).isBlank()) {
        indexOfMemoryStatsStart--;
      }
      break;
    }
  }
  if (indexOfMemoryStatsStart != -1) {
    var indexOfMemoryStatsEnd = -1;
    for (int i = indexOfMemoryStatsStart + 1; i < cppLines.size(); i++) {
      if (cppLines.get(i).trim().equals("}")) {
        indexOfMemoryStatsEnd = i;
        break;
      }
    }
    if (indexOfMemoryStatsEnd == -1) {
      throw new IllegalStateException("memory_stats method end not found");
    }
    cppLines.subList(indexOfMemoryStatsStart, indexOfMemoryStatsEnd + 1).clear();
  }

  if (cppLines.isEmpty()) {
    throw new IllegalStateException("Cannot insert memory_stats into an empty file: " + cppFile);
  }
  // Insert the method before the last line of the file, followed by one blank line.
  var last = cppLines.removeLast();
  cppLines.addAll(memoryStatsMethodLines);
  cppLines.add("");
  cppLines.addLast(last);

  boolean changed = !Objects.equals(inputCppLines, cppLines);
  if (changed) {
    System.out.printf("\tDone: %s.cpp file has been updated!%n", manager.name + "Manager");
    Files.write(cppFile, cppLines, StandardCharsets.UTF_8);
  } else {
    System.out.printf("\tDone: %s.cpp file did not change.%n", manager.name + "Manager");
  }
  return new UpdateResult(changed, fields.size());
}
/**
 * Regenerates {@code <name>Manager::print_managers_memory_stats} in the given manager's
 * {@code .cpp} file so that it calls {@code memory_stats} on every manager in
 * {@code managers}, and adds an {@code #include} for each of their headers when missing.
 *
 * <p>A previous definition of the method (together with the blank lines directly above it)
 * is removed, and the new definition is inserted before the last line of the file. Missing
 * includes are inserted after the last existing {@code #include} line. The file is only
 * written when its content changed. The process exits with status 1 when the header or the
 * source file of {@code manager} does not exist.
 *
 * @param manager the manager that receives the aggregating method
 * @param managers the managers whose {@code memory_stats} is called, in output order
 * @throws IOException if the source file cannot be read or written
 * @throws IllegalStateException if the end of the previous method cannot be found, the
 *     file has no line before which the method could be inserted, or it has no
 *     {@code #include} line after which a missing include could be added
 */
private static void updateMemoryManagerJson(Manager manager, SequencedSet<Manager> managers) throws IOException {
  Path hFile = manager.directory.resolve(manager.name + "Manager.h");
  Path cppFile = manager.directory.resolve(manager.name + "Manager.cpp");
  System.out.printf("Updating memory manager \"%s\" files: [\"%s\", \"%s\"]%n", manager.name, hFile, cppFile);
  if (Files.notExists(hFile)) {
    System.out.printf("File not found for manager \"%s\": \"%s\"%n", manager.name, hFile);
    System.exit(1);
    return;
  }
  if (Files.notExists(cppFile)) {
    System.out.printf("File not found for manager \"%s\": \"%s\"%n", manager.name, cppFile);
    System.exit(1);
    return;
  }

  // Generate the new method body: one JSON object per manager, separated by commas.
  StringBuilder memoryStatsMethod = new StringBuilder();
  memoryStatsMethod.append("void %sManager::print_managers_memory_stats(vector<string> &output) const {\n".formatted(manager.name));
  memoryStatsMethod.append(managers.stream()
      .map(m -> {
        var snakeName = toSnakeCase(m.name);
        return "output.emplace_back(\"\\\"%s_manager_\\\":{\"); td_->%s_manager_->memory_stats(output); output.emplace_back(\"}\");\n".formatted(snakeName, snakeName);
      })
      .collect(Collectors.joining(" output.emplace_back(\",\");\n")));
  memoryStatsMethod.append("}\n");
  List<String> memoryStatsMethodLines = Arrays.asList(memoryStatsMethod.toString().split("\n"));

  // Work on a private mutable copy so the edits never depend on the reader's list type.
  var cppLines = new ArrayList<>(readSourceFile(cppFile));
  var inputCppLines = new ArrayList<>(cppLines);

  // Remove the old print_managers_memory_stats method.
  var indexOfMemoryStatsStart = -1;
  for (int i = 0; i < cppLines.size(); i++) {
    if (cppLines.get(i).contains("::print_managers_memory_stats(")) {
      indexOfMemoryStatsStart = i;
      // Also drop the blank lines directly above, so repeated runs do not accumulate them.
      while (indexOfMemoryStatsStart > 0 && cppLines.get(indexOfMemoryStatsStart - 1).isBlank()) {
        indexOfMemoryStatsStart--;
      }
      break;
    }
  }
  if (indexOfMemoryStatsStart != -1) {
    var indexOfMemoryStatsEnd = -1;
    for (int i = indexOfMemoryStatsStart + 1; i < cppLines.size(); i++) {
      if (cppLines.get(i).trim().equals("}")) {
        indexOfMemoryStatsEnd = i;
        break;
      }
    }
    if (indexOfMemoryStatsEnd == -1) {
      throw new IllegalStateException("print_managers_memory_stats method end not found");
    }
    cppLines.subList(indexOfMemoryStatsStart, indexOfMemoryStatsEnd + 1).clear();
  }

  if (cppLines.isEmpty()) {
    throw new IllegalStateException("Cannot insert print_managers_memory_stats into an empty file: " + cppFile);
  }
  // Insert the method before the last line of the file, followed by one blank line.
  var last = cppLines.removeLast();
  cppLines.addAll(memoryStatsMethodLines);
  cppLines.add("");
  cppLines.addLast(last);

  // Make sure the header of every referenced manager is included.
  for (Manager m : managers) {
    var includePath = m.includePath();
    var mInclude = "#include \"%s\"".formatted(includePath);
    if (cppLines.contains(mInclude)) {
      continue;
    }
    // The path is passed as an argument: it must never be interpreted as a format string.
    System.out.printf("\tMissing include, adding \"%s\"%n", includePath);
    int includeInsertIndex = -1;
    for (int i = cppLines.size() - 1; i >= 0; i--) {
      if (cppLines.get(i).startsWith("#include")) {
        includeInsertIndex = i;
        break;
      }
    }
    if (includeInsertIndex == -1) {
      throw new IllegalStateException("Cannot find a place to put the include");
    }
    cppLines.add(includeInsertIndex + 1, mInclude);
  }

  boolean changed = !Objects.equals(inputCppLines, cppLines);
  if (changed) {
    System.out.printf("\tDone: %s.cpp file has been updated!%n", manager.name + "Manager");
    Files.write(cppFile, cppLines, StandardCharsets.UTF_8);
  } else {
    System.out.printf("\tDone: %s.cpp file did not change.%n", manager.name + "Manager");
  }
}
/**
 * Converts a CamelCase name to snake_case: the first code point is lower-cased, and every
 * later upper-case code point is replaced by an underscore followed by its lower-case form.
 * All other code points are copied unchanged.
 *
 * @param name the name to convert
 * @return the converted name; the empty string for an empty input
 */
private static String toSnakeCase(String name) {
  var snake = new StringBuilder();
  int offset = 0;
  while (offset < name.length()) {
    int codePoint = name.codePointAt(offset);
    if (offset == 0) {
      snake.appendCodePoint(Character.toLowerCase(codePoint));
    } else if (Character.isUpperCase(codePoint)) {
      snake.append('_').appendCodePoint(Character.toLowerCase(codePoint));
    } else {
      snake.appendCodePoint(codePoint);
    }
    // Advance by one or two chars depending on whether this was a supplementary code point.
    offset += Character.charCount(codePoint);
  }
  return snake.toString();
}
/**
 * Reads all lines of a UTF-8 source file.
 *
 * <p>The result is always a fresh, modifiable list: callers edit it in place, and the
 * modifiability of the list returned by {@link Files#readAllLines} is implementation
 * dependent.
 *
 * @param path the file to read
 * @return the lines of the file, without line terminators, as a modifiable list
 * @throws IOException if the file cannot be read or is not valid UTF-8
 */
private static List<String> readSourceFile(Path path) throws IOException {
  return new ArrayList<>(Files.readAllLines(path, StandardCharsets.UTF_8));
}
/**
 * Prepares header lines for field detection.
 *
 * <ul>
 *   <li>Blank lines are dropped.</li>
 *   <li>Everything from the first {@code //} of a line onwards is removed.</li>
 *   <li>A line whose remaining code ends with {@code >} (ignoring surrounding whitespace) is
 *       joined with the following line, separated by one space, so that a declaration whose
 *       type and name are on different lines becomes a single line.</li>
 * </ul>
 *
 * <p>The comment is removed before the continuation check, so comment text ending in
 * {@code >} never causes the next line to be merged and then cut off with the comment.
 *
 * @param lines the raw lines of the source file; not modified
 * @return a new modifiable list with the normalized lines
 * @throws IOException never thrown by this implementation; kept for interface compatibility
 */
private static List<String> normalizeSourceFile(List<String> lines) throws IOException {
  List<String> normalized = new ArrayList<>();
  StringBuilder pending = new StringBuilder();
  for (String line : lines) {
    if (line.isBlank()) {
      continue;
    }
    int commentStart = line.indexOf("//");
    String code = commentStart >= 0 ? line.substring(0, commentStart) : line;
    if (!pending.isEmpty()) {
      pending.append(" ");
    }
    pending.append(code);
    // Only code that ends with '>' is treated as an unfinished declaration.
    if (!code.trim().endsWith(">")) {
      normalized.add(pending.toString());
      pending.setLength(0);
    }
  }
  if (!pending.isEmpty()) {
    normalized.add(pending.toString());
  }
  return normalized;
}
}