diff --git a/__main__.py b/__main__.py index 0598114..6e1cbe3 100644 --- a/__main__.py +++ b/__main__.py @@ -7,6 +7,34 @@ from code_writer import CodeWriter natives = ["int", "long", "boolean", "String", "double", "byte[]", "byte"] +native_to_object = { + "boolean": "Boolean", + "byte[]": "Byte", + "byte": "Byte", + "int": "Integer", + "short": "Short", + "char": "Character", + "long": "Long", + "float": "Float", + "double": "Double", + "String": "String" +} + +cmp_natives = natives.copy() +cmp_natives.remove("String") +cmp_natives.remove("byte[]") + + +def hash_object(name: str, typ: str) -> str: + if not typ.endswith("[]"): + return f"{name} == null ? 0 : {name}.hashCode()" + + elif typ.endswith("[][]"): + return f"Arrays.deepHashCode({name})" + + else: + return f"Arrays.hashCode({name})" + def deserialize_tdapi(output: CodeWriter, arg_name: str, arg_type: str, cont, classes, null_check: bool = True): if null_check: @@ -436,8 +464,8 @@ def main(input_path: str, output_path: str, headers_path: str): output.newline() output.close_block(space=True) + output.newline() - output.newline() output.indent() output.open_constructor_function(class_name, [("DataInput", "input")], "IOException") output.newline() @@ -609,6 +637,116 @@ def main(input_path: str, output_path: str, headers_path: str): output.close_block(space=True) output.close_block(space=True) + output.close_block(space=True) + + output.newline() + output.indent() + output.open_function("equals", [("java.lang.Object", "o")], "boolean") + output.newline() + + output.indent() + output.open_if("this == o") + output.newline() + output.indent() + output.ret("true") + output.newline() + output.close_block(space=True) + + output.indent() + output.open_if("o == null || getClass() != o.getClass()") + output.newline() + output.indent() + output.ret("false") + output.newline() + output.close_block(space=True) + + if class_meta[2]: + output.indent() + other_class = class_name[0].lower() + class_name[1:] + output.local_assign(class_name, other_class, f"({class_name}) o") + + output.newline() + + for arg_type, arg_name in class_meta[2]: + output.indent() + + if arg_type in cmp_natives: + output.open_if(f"this.{arg_name} != {other_class}.{arg_name}") + + elif not arg_type.endswith("[]"): + output.open_if(f"!Objects.equals(this.{arg_name}, {other_class}.{arg_name})") + + elif arg_type.endswith("[][]"): + output.open_if(f"!Arrays.deepEquals(this.{arg_name}, {other_class}.{arg_name})") + + else: + output.open_if(f"!Arrays.equals(this.{arg_name}, {other_class}.{arg_name})") + + output.newline() + output.indent() + output.ret("false") + output.newline() + output.close_block(space=True) + + output.indent() + output.ret("true") + output.newline() + + output.close_block(space=True) + + output.newline() + output.indent() + output.open_function("hashCode", [], "int") + output.newline() + + output.indent() + + if class_meta[2]: + primitives = [(t, n) for t, n in class_meta[2] if t in cmp_natives] + + if primitives and len(class_meta[2]) == 1: + output.ret(f"{native_to_object[primitives[0][0]]}.hashCode(this.{primitives[0][1]})") + + elif primitives: + output.local_assign("int", "result", + f"{native_to_object[primitives[0][0]]}.hashCode(this.{primitives[0][1]})") + output.newline() + output.indent() + + for arg_type, arg_name in primitives[1:]: + output.assign("result", f"result * 31 + " + f"{native_to_object[arg_type]}.hashCode(this.{arg_name})") + output.newline() + output.indent() + + tdapi = [(t, n) for t, n in class_meta[2] if n not in [p[1] for p in primitives]] + + if tdapi and len(class_meta[2]) == 1: + output.ret(hash_object(f"this.{tdapi[0][1]}", tdapi[0][0])) + start = 1 + + else: + if not primitives: + output.local_assign("int", "result", hash_object(f'this.{tdapi[0][1]}', tdapi[0][0])) + output.newline() + output.indent() + start = 1 + else: + start = 0 + + for arg_type, arg_name in tdapi[start:]: + output.assign("result", f"result * 31 + ({hash_object(f'this.{arg_name}', arg_type)})") + output.newline() + output.indent() + + if len(class_meta[2]) > 1: + output.ret("result") + + else: + output.ret("CONSTRUCTOR") + + output.newline() + output.close_block(space=True) output.close_block(space=True) output.newline() diff --git a/code_writer.py b/code_writer.py index 97b000c..e1c184c 100644 --- a/code_writer.py +++ b/code_writer.py @@ -60,6 +60,9 @@ class CodeWriter: def local_assign(self, object_type: str, name: str, value: str): self.fd.write(object_type + " " + name + " = " + value + ";") + def assign(self, name: str, value: str): + self.fd.write(name + " = " + value + ";") + def open_for(self, start: str, cond: str, stmt: str): self.indent_depth += 1 self.fd.write("for (" + start + "; " + cond + "; " + stmt + ") {")