diff --git a/app/src/full/java/com/topjohnwu/magisk/SuRequestActivity.java b/app/src/full/java/com/topjohnwu/magisk/SuRequestActivity.java index 0adf95651..55ab75f46 100644 --- a/app/src/full/java/com/topjohnwu/magisk/SuRequestActivity.java +++ b/app/src/full/java/com/topjohnwu/magisk/SuRequestActivity.java @@ -48,6 +48,11 @@ public class SuRequestActivity extends BaseActivity { class SuConnectorV1 extends SuConnector { SuConnectorV1(String name) throws IOException { + super(name); + } + + @Override + public void connect(String name) throws IOException { socket.connect(new LocalSocketAddress(name, LocalSocketAddress.Namespace.FILESYSTEM)); new FileObserver(name) { @Override @@ -60,28 +65,25 @@ public class SuRequestActivity extends BaseActivity { } @Override - public void response() { - try (OutputStream out = getOutputStream()) { - out.write((policy.policy == Policy.ALLOW ? "socket:ALLOW" : "socket:DENY").getBytes()); - } catch (IOException e) { - e.printStackTrace(); - } + public void onResponse() throws IOException { + out.write((policy.policy == Policy.ALLOW ? "socket:ALLOW" : "socket:DENY").getBytes()); } } class SuConnectorV2 extends SuConnector { SuConnectorV2(String name) throws IOException { + super(name); + } + + @Override + public void connect(String name) throws IOException { socket.connect(new LocalSocketAddress(name, LocalSocketAddress.Namespace.ABSTRACT)); } @Override - public void response() { - try (DataOutputStream out = getOutputStream()) { - out.writeInt(policy.policy); - } catch (IOException e) { - e.printStackTrace(); - } + public void onResponse() throws IOException { + out.writeInt(policy.policy); } } @@ -119,9 +121,9 @@ public class SuRequestActivity extends BaseActivity { // Get policy Intent intent = getIntent(); try { + String socketName = intent.getStringExtra("socket"); connector = intent.getIntExtra("version", 1) == 1 ? - new SuConnectorV1(intent.getStringExtra("socket")) : - new SuConnectorV2(intent.getStringExtra("socket")); + new SuConnectorV1(socketName) : new SuConnectorV2(socketName); Bundle bundle = connector.readSocketInput(); int uid = Integer.parseInt(bundle.getString("uid")); policy = mm.mDB.getPolicy(uid); diff --git a/app/src/full/java/com/topjohnwu/magisk/utils/SuConnector.java b/app/src/full/java/com/topjohnwu/magisk/utils/SuConnector.java index 1d6d1129f..7ca379186 100644 --- a/app/src/full/java/com/topjohnwu/magisk/utils/SuConnector.java +++ b/app/src/full/java/com/topjohnwu/magisk/utils/SuConnector.java @@ -15,6 +15,8 @@ import com.topjohnwu.magisk.R; import com.topjohnwu.magisk.container.Policy; import com.topjohnwu.magisk.container.SuLogEntry; +import java.io.BufferedInputStream; +import java.io.BufferedOutputStream; import java.io.DataInputStream; import java.io.DataOutputStream; import java.io.IOException; @@ -23,31 +25,50 @@ import java.util.Date; public abstract class SuConnector { protected LocalSocket socket = new LocalSocket(); + protected DataOutputStream out; + protected DataInputStream in; - private String readString(DataInputStream is) throws IOException { - int len = is.readInt(); + public SuConnector(String name) throws IOException { + connect(name); + out = new DataOutputStream(new BufferedOutputStream(socket.getOutputStream())); + in = new DataInputStream(new BufferedInputStream(socket.getInputStream())); + } + + private String readString() throws IOException { + int len = in.readInt(); byte[] buf = new byte[len]; - is.readFully(buf); - return new String(buf); + in.readFully(buf); + return new String(buf, "UTF-8"); } public Bundle readSocketInput() throws IOException { Bundle bundle = new Bundle(); - DataInputStream is = new DataInputStream(socket.getInputStream()); while (true) { - String name = readString(is); + String name = readString(); if (TextUtils.equals(name, "eof")) break; - bundle.putString(name, readString(is)); + bundle.putString(name, readString()); } return bundle; } - protected DataOutputStream getOutputStream() throws IOException { - return new DataOutputStream(socket.getOutputStream()); + public void response() { + try { + onResponse(); + out.flush(); + } catch (IOException e) { + e.printStackTrace(); + } + try { + in.close(); + out.close(); + socket.close(); + } catch (IOException ignored) { } } - public abstract void response(); + public abstract void connect(String name) throws IOException; + + protected abstract void onResponse() throws IOException; public static void handleLogs(Intent intent, int version) {