Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,16 +25,16 @@
import static google.registry.request.Action.Method.POST;
import static jakarta.servlet.http.HttpServletResponse.SC_INTERNAL_SERVER_ERROR;
import static java.nio.charset.StandardCharsets.US_ASCII;
import static java.nio.charset.StandardCharsets.UTF_8;

import com.google.cloud.storage.BlobId;
import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.ImmutableSortedSet;
import com.google.common.collect.Ordering;
import com.google.common.flogger.FluentLogger;
import com.google.common.hash.Hasher;
import com.google.common.hash.Hashing;
import com.google.common.io.ByteSource;
import google.registry.bsa.api.BsaCredential;
import google.registry.config.RegistryConfig.Config;
import google.registry.gcs.GcsUtils;
Expand All @@ -47,10 +47,13 @@
import google.registry.util.Clock;
import jakarta.inject.Inject;
import jakarta.persistence.TypedQuery;
import java.io.ByteArrayOutputStream;
import java.io.BufferedInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.PipedInputStream;
import java.io.PipedOutputStream;
import java.io.Writer;
import java.util.Optional;
import java.util.zip.GZIPOutputStream;
Expand All @@ -60,14 +63,17 @@
import okhttp3.Request;
import okhttp3.RequestBody;
import okhttp3.Response;
import okio.BufferedSink;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
import org.joda.time.DateTime;

/**
* Daily action that uploads unavailable domain names on applicable TLDs to BSA.
*
* <p>The upload is a single zipped text file containing combined details for all BSA-enrolled TLDs.
* The text is a newline-delimited list of punycoded fully qualified domain names, and contains all
* domains on each TLD that are registered and/or reserved.
* The text is a newline-delimited list of punycoded fully qualified domain names with a trailing
* newline at the end, and contains all domains on each TLD that are registered and/or reserved.
*
* <p>The file is also uploaded to GCS to preserve it as a record for ourselves.
*/
Expand Down Expand Up @@ -118,7 +124,7 @@ public void run() {
// TODO(mcilwain): Implement a date Cursor, have the cronjob run frequently, and short-circuit
// the run if the daily upload is already completed.
DateTime runTime = clock.nowUtc();
String unavailableDomains = Joiner.on("\n").join(getUnavailableDomains(runTime));
ImmutableSortedSet<String> unavailableDomains = getUnavailableDomains(runTime);
if (unavailableDomains.isEmpty()) {
logger.atWarning().log("No unavailable domains found; terminating.");
emailSender.sendNotification(
Expand All @@ -136,12 +142,15 @@ public void run() {
}

/** Uploads the unavailable domains list to GCS in the unavailable domains bucket. */
boolean uploadToGcs(String unavailableDomains, DateTime runTime) {
boolean uploadToGcs(ImmutableSortedSet<String> unavailableDomains, DateTime runTime) {
logger.atInfo().log("Uploading unavailable names file to GCS in bucket %s", gcsBucket);
BlobId blobId = BlobId.of(gcsBucket, createFilename(runTime));
try (OutputStream gcsOutput = gcsUtils.openOutputStream(blobId);
Writer osWriter = new OutputStreamWriter(gcsOutput, US_ASCII)) {
osWriter.write(unavailableDomains);
for (var domainName : unavailableDomains) {
osWriter.write(domainName);
osWriter.write("\n");
}
return true;
} catch (Exception e) {
logger.atSevere().withCause(e).log(
Expand All @@ -150,10 +159,14 @@ boolean uploadToGcs(String unavailableDomains, DateTime runTime) {
}
}

boolean uploadToBsa(String unavailableDomains, DateTime runTime) {
boolean uploadToBsa(ImmutableSortedSet<String> unavailableDomains, DateTime runTime) {
try {
byte[] gzippedContents = gzipUnavailableDomains(unavailableDomains);
String sha512Hash = ByteSource.wrap(gzippedContents).hash(Hashing.sha512()).toString();
Hasher sha512Hasher = Hashing.sha512().newHasher();
unavailableDomains.stream()
.map(name -> name + "\n")
.forEachOrdered(line -> sha512Hasher.putString(line, UTF_8));
String sha512Hash = sha512Hasher.hash().toString();

String filename = createFilename(runTime);
OkHttpClient client = new OkHttpClient().newBuilder().build();

Expand All @@ -169,7 +182,9 @@ boolean uploadToBsa(String unavailableDomains, DateTime runTime) {
.addFormDataPart(
"file",
String.format("%s.gz", filename),
RequestBody.create(gzippedContents, MediaType.parse("application/octet-stream")))
new StreamingRequestBody(
gzippedStream(unavailableDomains),
MediaType.parse("application/octet-stream")))
.build();

Request request =
Expand All @@ -196,15 +211,6 @@ boolean uploadToBsa(String unavailableDomains, DateTime runTime) {
}
}

private byte[] gzipUnavailableDomains(String unavailableDomains) throws IOException {
try (ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream()) {
try (GZIPOutputStream gzipOutputStream = new GZIPOutputStream(byteArrayOutputStream)) {
gzipOutputStream.write(unavailableDomains.getBytes(US_ASCII));
}
return byteArrayOutputStream.toByteArray();
}
}

private static String createFilename(DateTime runTime) {
return String.format("unavailable_domains_%s.txt", runTime.toString());
}
Expand Down Expand Up @@ -280,4 +286,64 @@ private ImmutableSortedSet<String> getUnavailableDomains(DateTime runTime) {
private static String toDomain(String domainLabel, Tld tld) {
return String.format("%s.%s", domainLabel, tld.getTldStr());
}

private InputStream gzippedStream(ImmutableSortedSet<String> unavailableDomains)
throws IOException {
PipedInputStream inputStream = new PipedInputStream();
PipedOutputStream outputStream = new PipedOutputStream(inputStream);

new Thread(
() -> {
try {
gzipUnavailableDomains(outputStream, unavailableDomains);
} catch (Throwable e) {
logger.atSevere().withCause(e).log("Failed to gzip unavailable domains.");
try {
// This will cause the next read to throw an IOException.
inputStream.close();
} catch (IOException ignore) {
//
}
}
})
.start();

return inputStream;
}

private void gzipUnavailableDomains(
PipedOutputStream outputStream, ImmutableSortedSet<String> unavailableDomains)
throws IOException {
try (GZIPOutputStream gzipOutputStream = new GZIPOutputStream(outputStream)) {
for (String name : unavailableDomains) {
var line = name + "\n";
gzipOutputStream.write(line.getBytes(US_ASCII));
}
}
}

private static class StreamingRequestBody extends RequestBody {
private final BufferedInputStream inputStream;
private final MediaType mediaType;

StreamingRequestBody(InputStream inputStream, MediaType mediaType) {
this.inputStream = new BufferedInputStream(inputStream);
this.mediaType = mediaType;
}

@Nullable
@Override
public MediaType contentType() {
return mediaType;
}

@Override
public void writeTo(@NotNull BufferedSink bufferedSink) throws IOException {
byte[] buffer = new byte[2048];
int bytesRead;
while ((bytesRead = inputStream.read(buffer)) != -1) {
bufferedSink.write(buffer, 0, bytesRead);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,24 @@
import static google.registry.testing.DatabaseHelper.persistDeletedDomain;
import static google.registry.testing.DatabaseHelper.persistReservedList;
import static google.registry.testing.DatabaseHelper.persistResource;
import static google.registry.testing.LogsSubject.assertAboutLogs;
import static google.registry.util.DateTimeUtils.START_OF_TIME;
import static google.registry.util.NetworkUtils.pickUnusedPort;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.concurrent.Executors.newSingleThreadExecutor;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

import com.google.cloud.storage.BlobId;
import com.google.cloud.storage.contrib.nio.testing.LocalStorageHelper;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.flogger.FluentLogger;
import com.google.common.hash.Hashing;
import com.google.common.io.ByteStreams;
import com.google.common.net.HostAndPort;
import com.google.common.testing.TestLogHandler;
import com.google.gson.Gson;
import google.registry.bsa.api.BsaCredential;
import google.registry.gcs.GcsUtils;
import google.registry.model.tld.Tld;
Expand All @@ -35,12 +46,29 @@
import google.registry.persistence.transaction.JpaTestExtensions;
import google.registry.persistence.transaction.JpaTestExtensions.JpaIntegrationTestExtension;
import google.registry.request.UrlConnectionService;
import google.registry.server.Route;
import google.registry.server.TestServer;
import google.registry.testing.FakeClock;
import google.registry.testing.FakeResponse;
import jakarta.servlet.ServletException;
import jakarta.servlet.annotation.MultipartConfig;
import jakarta.servlet.http.HttpServlet;
import jakarta.servlet.http.HttpServletRequest;
import jakarta.servlet.http.HttpServletResponse;
import jakarta.servlet.http.Part;
import java.io.IOException;
import java.io.InputStream;
import java.io.PrintWriter;
import java.net.InetAddress;
import java.util.Map;
import java.util.Optional;
import java.util.logging.Level;
import java.util.logging.Logger;
import java.util.zip.GZIPInputStream;
import org.joda.time.DateTime;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.condition.DisabledIfEnvironmentVariable;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.api.extension.RegisterExtension;
import org.mockito.Mock;
Expand Down Expand Up @@ -102,13 +130,114 @@ void calculatesEntriesCorrectly() throws Exception {
BlobId existingFile =
BlobId.of(BUCKET, String.format("unavailable_domains_%s.txt", clock.nowUtc()));
String blockList = new String(gcsUtils.readBytesFrom(existingFile), UTF_8);
assertThat(blockList).isEqualTo("ace.tld\nflagrant.tld\nfoobar.tld\njimmy.tld\ntine.tld");
assertThat(blockList).isEqualTo("ace.tld\nflagrant.tld\nfoobar.tld\njimmy.tld\ntine.tld\n");
assertThat(blockList).doesNotContain("not-blocked.tld");

// This test currently fails in the upload-to-bsa step.
verify(emailSender, times(1))
.sendNotification("BSA daily upload completed with errors", "Please see logs for details.");
}

// TODO(weiminyu): Breaks other tests on Kokoro. Investigate.
@DisabledIfEnvironmentVariable(named = "KOKORO_JOB_NAME", matches = ".*")
@Test
void uploadToBsaTest() throws Exception {
TestLogHandler logHandler = new TestLogHandler();
Logger loggerToIntercept =
Logger.getLogger(UploadBsaUnavailableDomainsAction.class.getCanonicalName());
loggerToIntercept.addHandler(logHandler);

persistActiveDomain("foobar.tld");
persistActiveDomain("ace.tld");
persistDeletedDomain("not-blocked.tld", clock.nowUtc().minusDays(1));

var testServer = startTestServer();
action.apiUrl = testServer.getUrl("/upload").toURI().toString();
try {
action.run();
} finally {
testServer.stop();
}
String dataSent = "ace.tld\nflagrant.tld\nfoobar.tld\njimmy.tld\ntine.tld\n";
String checkSum = Hashing.sha512().hashString(dataSent, UTF_8).toString();
String expectedResponse =
"Received response with code 200 from server: "
+ String.format("Checksum: [%s]\n%s\n", checkSum, dataSent);
assertAboutLogs().that(logHandler).hasLogAtLevelWithMessage(Level.INFO, expectedResponse);
verify(emailSender, times(1)).sendNotification("BSA daily upload completed successfully", "");
}

private TestServer startTestServer() throws Exception {
TestServer testServer =
new TestServer(
HostAndPort.fromParts(InetAddress.getLocalHost().getHostAddress(), pickUnusedPort()),
ImmutableMap.of(),
ImmutableList.of(Route.route("/upload", Servelet.class)));
testServer.start();
newSingleThreadExecutor()
.execute(
() -> {
try {
while (true) {
testServer.process();
}
} catch (InterruptedException e) {
// Expected
}
});
return testServer;
}

// TODO(mcilwain): Add test of BSA API upload as well.
@MultipartConfig(
location = "", // Directory for storing uploaded files. Use default when blank
maxFileSize = 10485760L, // 10MB
maxRequestSize = 20971520L, // 20MB
fileSizeThreshold = 1048576 // Save in memory if file size < 1MB
)
public static class Servelet extends HttpServlet {
private static final FluentLogger logger = FluentLogger.forEnclosingClass();

@Override
protected void doPost(HttpServletRequest req, HttpServletResponse resp)
throws ServletException, IOException {
String checkSum = null;
String content = null;
try {
for (Part part : req.getParts()) {
switch (part.getName()) {
case "zone" -> checkSum = readChecksum(part);
case "file" -> content = readGzipped(part);
}
}
} catch (Exception e) {
logger.atInfo().withCause(e).log("");
}
int status = checkSum == null || content == null ? 400 : 200;
resp.setStatus(status);
resp.setContentType("text/plain");
try (PrintWriter writer = resp.getWriter()) {
writer.printf("Checksum: [%s]\n%s\n", checkSum, content);
}
}

private String readChecksum(Part part) {
try (InputStream is = part.getInputStream()) {
return new Gson()
.fromJson(new String(ByteStreams.toByteArray(is), UTF_8), Map.class)
.getOrDefault("checkSum", "Not found")
.toString();
} catch (IOException e) {
throw new RuntimeException(e);
}
}

private String readGzipped(Part part) {
try (InputStream is = part.getInputStream();
GZIPInputStream gis = new GZIPInputStream(is)) {
return new String(ByteStreams.toByteArray(gis), UTF_8);
} catch (IOException e) {
throw new RuntimeException(e);
}
}
}
}
Loading
Loading