Use the least possible memory for boot signing and verification

Close #971, close #966
This commit is contained in:
topjohnwu 2019-01-16 17:12:23 -05:00
parent 23e5188422
commit 85042fbe25
2 changed files with 84 additions and 32 deletions

View File

@ -4,20 +4,18 @@ import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
public class ByteArrayStream extends ByteArrayOutputStream {
public byte[] getBuf() {
return buf;
}
public synchronized void readFrom(InputStream is) {
readFrom(is, Integer.MAX_VALUE);
}
public synchronized void readFrom(InputStream is, int len) {
int read;
byte buffer[] = new byte[4096];
try {
while ((read = is.read(buffer, 0, len < buffer.length ? len : buffer.length)) > 0) {
while ((read = is.read(buffer, 0, Math.min(len, buffer.length))) > 0) {
write(buffer, 0, read);
len -= read;
}
@ -25,9 +23,7 @@ public class ByteArrayStream extends ByteArrayOutputStream {
e.printStackTrace();
}
}
public synchronized void writeTo(OutputStream out, int off, int len) throws IOException {
out.write(buf, off, len);
}
public ByteArrayInputStream getInputStream() {
return new ByteArrayInputStream(buf, 0, count);
}

View File

@ -15,6 +15,7 @@ import org.bouncycastle.asn1.x509.AlgorithmIdentifier;
import org.bouncycastle.jce.provider.BouncyCastleProvider;
import java.io.ByteArrayInputStream;
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
@ -35,20 +36,62 @@ public class SignBoot {
Security.addProvider(new BouncyCastleProvider());
}
private static class PushBackRWStream extends FilterInputStream {
private OutputStream out;
private int pos = 0;
private byte[] backBuf;
PushBackRWStream(InputStream in, OutputStream o) {
super(in);
out = o;
}
@Override
public int read() throws IOException {
int b;
if (backBuf != null && backBuf.length > pos) {
b = backBuf[pos++];
} else {
b = super.read();
out.write(b);
}
return b;
}
@Override
public int read(byte[] bytes, int off, int len) throws IOException {
int read = 0;
if (backBuf != null && backBuf.length > pos) {
read = Math.min(len, backBuf.length - pos);
System.arraycopy(backBuf, pos, bytes, off, read);
pos += read;
off += read;
len -= read;
}
if (len > 0) {
int ar = super.read(bytes, off, len);
read += ar;
out.write(bytes, off, ar);
}
return read;
}
void unread(byte[] buf) {
backBuf = buf;
}
}
public static boolean doSignature(String target, InputStream imgIn, OutputStream imgOut,
InputStream cert, InputStream key) {
try {
ByteArrayStream image = new ByteArrayStream();
image.readFrom(imgIn);
int signableSize = getSignableImageSize(image.getBuf());
if (signableSize < image.size()) {
System.err.println("NOTE: truncating input from " +
image.size() + " to " + signableSize + " bytes");
} else if (signableSize > image.size()) {
throw new IllegalArgumentException("Invalid image: too short, expected " +
signableSize + " bytes");
}
BootSignature bootsig = new BootSignature(target, image.size());
PushBackRWStream in = new PushBackRWStream(imgIn, imgOut);
byte[] hdr = new byte[1024];
// First read the header
in.read(hdr);
int signableSize = getSignableImageSize(hdr);
// Unread header
in.unread(hdr);
BootSignature bootsig = new BootSignature(target, signableSize);
if (cert == null) {
cert = SignBoot.class.getResourceAsStream("/keys/testkey.x509.pem");
}
@ -58,10 +101,9 @@ public class SignBoot {
key = SignBoot.class.getResourceAsStream("/keys/testkey.pk8");
}
PrivateKey privateKey = CryptoUtils.readPrivateKey(key);
bootsig.setSignature(bootsig.sign(privateKey, image.getBuf(), signableSize),
CryptoUtils.getSignatureAlgorithmIdentifier(privateKey));
byte[] sig = bootsig.sign(privateKey, in, signableSize);
bootsig.setSignature(sig, CryptoUtils.getSignatureAlgorithmIdentifier(privateKey));
byte[] encoded_bootsig = bootsig.getEncoded();
image.writeTo(imgOut);
imgOut.write(encoded_bootsig);
imgOut.flush();
return true;
@ -73,19 +115,29 @@ public class SignBoot {
public static boolean verifySignature(InputStream imgIn, InputStream certIn) {
try {
ByteArrayStream image = new ByteArrayStream();
image.readFrom(imgIn);
int signableSize = getSignableImageSize(image.getBuf());
if (signableSize >= image.size()) {
// Read the header for size
byte[] hdr = new byte[1024];
if (imgIn.read(hdr) != hdr.length)
return false;
int signableSize = getSignableImageSize(hdr);
// Read the rest of the image
byte[] rawImg = Arrays.copyOf(hdr, signableSize);
int remain = signableSize - hdr.length;
if (imgIn.read(rawImg, hdr.length, remain) != remain) {
System.err.println("Invalid image: not signed");
return false;
}
byte[] signature = Arrays.copyOfRange(image.getBuf(), signableSize, image.size());
// Read footer, which contains the signature
byte[] signature = new byte[4096];
imgIn.read(signature);
BootSignature bootsig = new BootSignature(signature);
if (certIn != null) {
bootsig.setCertificate(CryptoUtils.readCertificate(certIn));
}
if (bootsig.verify(image.getBuf(), signableSize)) {
if (bootsig.verify(rawImg, signableSize)) {
System.err.println("Signature is VALID");
return true;
} else {
@ -148,8 +200,7 @@ public class SignBoot {
* Initializes the object for verifying a signed image file
* @param signature Signature footer
*/
public BootSignature(byte[] signature)
throws Exception {
public BootSignature(byte[] signature) throws Exception {
ASN1InputStream stream = new ASN1InputStream(signature);
ASN1Sequence sequence = (ASN1Sequence) stream.readObject();
formatVersion = (ASN1Integer) sequence.getObjectAt(0);
@ -193,10 +244,15 @@ public class SignBoot {
publicKey = cert.getPublicKey();
}
public byte[] sign(PrivateKey key, byte[] image, int length) throws Exception {
public byte[] sign(PrivateKey key, InputStream is, int len) throws Exception {
Signature signer = Signature.getInstance(CryptoUtils.getSignatureAlgorithm(key));
signer.initSign(key);
signer.update(image, 0, length);
int read;
byte buffer[] = new byte[4096];
while ((read = is.read(buffer, 0, Math.min(len, buffer.length))) > 0) {
signer.update(buffer, 0, read);
len -= read;
}
signer.update(getEncodedAuthenticatedAttributes());
return signer.sign();
}