diff --git a/pom.xml b/pom.xml index fc75358..8281b4c 100644 --- a/pom.xml +++ b/pom.xml @@ -165,6 +165,11 @@ junit-jupiter test + + org.mockito + mockito-core + test + diff --git a/src/test/java/de/rub/nds/crawler/config/ControllerCommandConfigTest.java b/src/test/java/de/rub/nds/crawler/config/ControllerCommandConfigTest.java new file mode 100644 index 0000000..1f372d5 --- /dev/null +++ b/src/test/java/de/rub/nds/crawler/config/ControllerCommandConfigTest.java @@ -0,0 +1,300 @@ +/* + * TLS-Crawler - A TLS scanning tool to perform large scale scans with the TLS-Scanner + * + * Copyright 2018-2022 Ruhr University Bochum, Paderborn University, and Hackmanit GmbH + * + * Licensed under Apache License, Version 2.0 + * http://www.apache.org/licenses/LICENSE-2.0.txt + */ +package de.rub.nds.crawler.config; + +import static org.junit.jupiter.api.Assertions.*; + +import com.beust.jcommander.JCommander; +import com.beust.jcommander.ParameterException; +import de.rub.nds.crawler.constant.CruxListNumber; +import de.rub.nds.crawler.core.BulkScanWorker; +import de.rub.nds.crawler.data.BulkScan; +import de.rub.nds.crawler.data.ScanConfig; +import de.rub.nds.crawler.targetlist.*; +import de.rub.nds.scanner.core.config.ScannerDetail; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +class ControllerCommandConfigTest { + + private TestControllerCommandConfig config; + + private static class TestControllerCommandConfig extends ControllerCommandConfig { + @Override + public ScanConfig getScanConfig() { + return new ScanConfig(ScannerDetail.NORMAL, 3, 2000) { + @Override + public BulkScanWorker createWorker( + String bulkScanID, int parallelConnectionThreads, int parallelScanThreads) { + return null; + } + }; + } + + @Override + public Class getScannerClassForVersion() { + return String.class; // Dummy class for testing + } + } + + @BeforeEach + void setUp() { + config = new TestControllerCommandConfig(); + } + + @Test + void testDefaultValues() { + assertEquals(443, config.getPort()); + assertEquals(ScannerDetail.NORMAL, config.getScanDetail()); + assertEquals(2000, config.getScannerTimeout()); + assertEquals(3, config.getReexecutions()); + assertNull(config.getScanCronInterval()); + assertNull(config.getScanName()); + assertNull(config.getHostFile()); + assertNull(config.getDenylistFile()); + assertFalse(config.isMonitored()); + assertNull(config.getNotifyUrl()); + assertEquals(0, config.getTranco()); + assertNull(config.getCrux()); + assertEquals(0, config.getTrancoEmail()); + } + + @Test + void testSettersAndGetters() { + config.setPort(8443); + assertEquals(8443, config.getPort()); + + config.setScanDetail(ScannerDetail.DETAILED); + assertEquals(ScannerDetail.DETAILED, config.getScanDetail()); + + config.setScannerTimeout(5000); + assertEquals(5000, config.getScannerTimeout()); + + config.setReexecutions(5); + assertEquals(5, config.getReexecutions()); + + config.setScanCronInterval("0 0 * * *"); + assertEquals("0 0 * * *", config.getScanCronInterval()); + + config.setScanName("test-scan"); + assertEquals("test-scan", config.getScanName()); + + config.setHostFile("/path/to/hosts"); + assertEquals("/path/to/hosts", config.getHostFile()); + + config.setDenylistFile("/path/to/denylist"); + assertEquals("/path/to/denylist", config.getDenylistFile()); + + config.setMonitored(true); + assertTrue(config.isMonitored()); + + config.setNotifyUrl("http://example.com/notify"); + assertEquals("http://example.com/notify", config.getNotifyUrl()); + + config.setTranco(1000); + assertEquals(1000, config.getTranco()); + + config.setCrux(CruxListNumber.TOP_10K); + assertEquals(CruxListNumber.TOP_10K, config.getCrux()); + + config.setTrancoEmail(500); + assertEquals(500, config.getTrancoEmail()); + } + + @Test + void testValidateNoTargetListProvider() { + // No host file, tranco, crux, or trancoEmail set + assertThrows(ParameterException.class, () -> config.validate()); + } + + @Test + void testValidateNotifyUrlWithoutMonitoring() { + config.setHostFile("/path/to/hosts"); + config.setNotifyUrl("http://example.com/notify"); + config.setMonitored(false); + + assertThrows(ParameterException.class, () -> config.validate()); + } + + @Test + void testValidateInvalidNotifyUrl() { + config.setHostFile("/path/to/hosts"); + config.setNotifyUrl("not-a-valid-url"); + config.setMonitored(true); + + assertThrows(ParameterException.class, () -> config.validate()); + } + + @Test + void testValidateSuccessful() { + config.setHostFile("/path/to/hosts"); + config.setNotifyUrl("http://example.com/notify"); + config.setMonitored(true); + + assertDoesNotThrow(() -> config.validate()); + } + + @Test + void testValidateWithTranco() { + config.setTranco(1000); + assertDoesNotThrow(() -> config.validate()); + } + + @Test + void testValidateWithCrux() { + config.setCrux(CruxListNumber.TOP_5K); + assertDoesNotThrow(() -> config.validate()); + } + + @Test + void testValidateWithTrancoEmail() { + config.setTrancoEmail(500); + assertDoesNotThrow(() -> config.validate()); + } + + @Test + void testGetTargetListProviderHostFile() { + config.setHostFile("/path/to/hosts"); + ITargetListProvider provider = config.getTargetListProvider(); + assertInstanceOf(TargetFileProvider.class, provider); + } + + @Test + void testGetTargetListProviderTrancoEmail() { + config.setTrancoEmail(500); + ITargetListProvider provider = config.getTargetListProvider(); + assertInstanceOf(TrancoEmailListProvider.class, provider); + } + + @Test + void testGetTargetListProviderCrux() { + config.setCrux(CruxListNumber.TOP_10K); + ITargetListProvider provider = config.getTargetListProvider(); + assertInstanceOf(CruxListProvider.class, provider); + } + + @Test + void testGetTargetListProviderTranco() { + config.setTranco(1000); + ITargetListProvider provider = config.getTargetListProvider(); + assertInstanceOf(TrancoListProvider.class, provider); + } + + @Test + void testCreateBulkScan() { + config.setScanName("test-scan"); + config.setMonitored(true); + config.setNotifyUrl("http://example.com/notify"); + + BulkScan bulkScan = config.createBulkScan(); + + assertEquals("test-scan", bulkScan.getName()); + assertTrue(bulkScan.isMonitored()); + assertEquals("http://example.com/notify", bulkScan.getNotifyUrl()); + assertNotNull(bulkScan.getScannerVersion()); + assertNotNull(bulkScan.getCrawlerVersion()); + assertNotNull(bulkScan.getScanConfig()); + assertTrue(bulkScan.getStartTime() > 0); + } + + @Test + void testPositiveIntegerValidator() { + ControllerCommandConfig.PositiveInteger validator = + new ControllerCommandConfig.PositiveInteger(); + + assertDoesNotThrow(() -> validator.validate("test", "0")); + assertDoesNotThrow(() -> validator.validate("test", "100")); + assertThrows(ParameterException.class, () -> validator.validate("test", "-1")); + } + + @Test + void testCronSyntaxValidator() { + ControllerCommandConfig.CronSyntax validator = new ControllerCommandConfig.CronSyntax(); + + assertDoesNotThrow(() -> validator.validate("test", "0 0 * * *")); + assertDoesNotThrow(() -> validator.validate("test", "0 */5 * * *")); + assertThrows(Exception.class, () -> validator.validate("test", "invalid cron")); + } + + @Test + void testJCommanderParsing() { + String[] args = { + "-portToBeScanned", + "8443", + "-scanDetail", + "DETAILED", + "-timeout", + "5000", + "-reexecutions", + "5", + "-scanCronInterval", + "0 0 * * *", + "-scanName", + "my-scan", + "-hostFile", + "/path/to/hosts", + "-denylist", + "/path/to/denylist", + "-monitorScan", + "-notifyUrl", + "http://example.com/notify", + "-tranco", + "1000" + }; + + JCommander jCommander = JCommander.newBuilder().addObject(config).build(); + jCommander.parse(args); + + assertEquals(8443, config.getPort()); + assertEquals(ScannerDetail.DETAILED, config.getScanDetail()); + assertEquals(5000, config.getScannerTimeout()); + assertEquals(5, config.getReexecutions()); + assertEquals("0 0 * * *", config.getScanCronInterval()); + assertEquals("my-scan", config.getScanName()); + assertEquals("/path/to/hosts", config.getHostFile()); + assertEquals("/path/to/denylist", config.getDenylistFile()); + assertTrue(config.isMonitored()); + assertEquals("http://example.com/notify", config.getNotifyUrl()); + assertEquals(1000, config.getTranco()); + } + + @Test + void testJCommanderParsingNegativeTimeout() { + String[] args = { + "-hostFile", "/path/to/hosts", + "-timeout", "-1000" + }; + + JCommander jCommander = JCommander.newBuilder().addObject(config).build(); + + assertThrows(ParameterException.class, () -> jCommander.parse(args)); + } + + @Test + void testJCommanderParsingInvalidCron() { + String[] args = { + "-hostFile", "/path/to/hosts", + "-scanCronInterval", "invalid cron expression" + }; + + JCommander jCommander = JCommander.newBuilder().addObject(config).build(); + + assertThrows(Exception.class, () -> jCommander.parse(args)); + } + + @Test + void testGetCrawlerClassForVersion() { + assertEquals(TestControllerCommandConfig.class, config.getCrawlerClassForVersion()); + } + + @Test + void testGetScannerClassForVersion() { + assertEquals(String.class, config.getScannerClassForVersion()); + } +} diff --git a/src/test/java/de/rub/nds/crawler/config/WorkerCommandConfigTest.java b/src/test/java/de/rub/nds/crawler/config/WorkerCommandConfigTest.java new file mode 100644 index 0000000..f4256b9 --- /dev/null +++ b/src/test/java/de/rub/nds/crawler/config/WorkerCommandConfigTest.java @@ -0,0 +1,95 @@ +/* + * TLS-Crawler - A TLS scanning tool to perform large scale scans with the TLS-Scanner + * + * Copyright 2018-2022 Ruhr University Bochum, Paderborn University, and Hackmanit GmbH + * + * Licensed under Apache License, Version 2.0 + * http://www.apache.org/licenses/LICENSE-2.0.txt + */ +package de.rub.nds.crawler.config; + +import static org.junit.jupiter.api.Assertions.*; + +import com.beust.jcommander.JCommander; +import de.rub.nds.crawler.config.delegate.MongoDbDelegate; +import de.rub.nds.crawler.config.delegate.RabbitMqDelegate; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +class WorkerCommandConfigTest { + + private WorkerCommandConfig config; + + @BeforeEach + void setUp() { + config = new WorkerCommandConfig(); + } + + @Test + void testDefaultValues() { + assertEquals(Runtime.getRuntime().availableProcessors(), config.getParallelScanThreads()); + assertEquals(20, config.getParallelConnectionThreads()); + assertEquals(840000, config.getScanTimeout()); + assertNotNull(config.getRabbitMqDelegate()); + assertNotNull(config.getMongoDbDelegate()); + } + + @Test + void testSettersAndGetters() { + config.setParallelScanThreads(10); + assertEquals(10, config.getParallelScanThreads()); + + config.setParallelConnectionThreads(30); + assertEquals(30, config.getParallelConnectionThreads()); + + config.setScanTimeout(1000000); + assertEquals(1000000, config.getScanTimeout()); + } + + @Test + void testJCommanderParsing() { + String[] args = { + "-numberOfThreads", "8", + "-parallelProbeThreads", "25", + "-scanTimeout", "900000" + }; + + JCommander jCommander = JCommander.newBuilder().addObject(config).build(); + jCommander.parse(args); + + assertEquals(8, config.getParallelScanThreads()); + assertEquals(25, config.getParallelConnectionThreads()); + assertEquals(900000, config.getScanTimeout()); + } + + @Test + void testJCommanderParsingWithDelegates() { + String[] args = { + "-numberOfThreads", "4", + "-rabbitMqHost", "rabbitmq.example.com", + "-rabbitMqPort", "5673", + "-mongoDbHost", "mongo.example.com", + "-mongoDbPort", "27018" + }; + + JCommander jCommander = JCommander.newBuilder().addObject(config).build(); + jCommander.parse(args); + + assertEquals(4, config.getParallelScanThreads()); + assertEquals("rabbitmq.example.com", config.getRabbitMqDelegate().getRabbitMqHost()); + assertEquals(5673, config.getRabbitMqDelegate().getRabbitMqPort()); + assertEquals("mongo.example.com", config.getMongoDbDelegate().getMongoDbHost()); + assertEquals(27018, config.getMongoDbDelegate().getMongoDbPort()); + } + + @Test + void testDelegatesNotNull() { + RabbitMqDelegate rabbitMqDelegate = config.getRabbitMqDelegate(); + MongoDbDelegate mongoDbDelegate = config.getMongoDbDelegate(); + + assertNotNull(rabbitMqDelegate); + assertNotNull(mongoDbDelegate); + assertSame(rabbitMqDelegate, config.getRabbitMqDelegate()); + assertSame(mongoDbDelegate, config.getMongoDbDelegate()); + } +} diff --git a/src/test/java/de/rub/nds/crawler/config/delegate/MongoDbDelegateTest.java b/src/test/java/de/rub/nds/crawler/config/delegate/MongoDbDelegateTest.java new file mode 100644 index 0000000..d50873d --- /dev/null +++ b/src/test/java/de/rub/nds/crawler/config/delegate/MongoDbDelegateTest.java @@ -0,0 +1,135 @@ +/* + * TLS-Crawler - A TLS scanning tool to perform large scale scans with the TLS-Scanner + * + * Copyright 2018-2022 Ruhr University Bochum, Paderborn University, and Hackmanit GmbH + * + * Licensed under Apache License, Version 2.0 + * http://www.apache.org/licenses/LICENSE-2.0.txt + */ +package de.rub.nds.crawler.config.delegate; + +import static org.junit.jupiter.api.Assertions.*; + +import com.beust.jcommander.JCommander; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +class MongoDbDelegateTest { + + private MongoDbDelegate delegate; + + @BeforeEach + void setUp() { + delegate = new MongoDbDelegate(); + } + + @Test + void testDefaultValues() { + assertNull(delegate.getMongoDbHost()); + assertEquals(0, delegate.getMongoDbPort()); + assertNull(delegate.getMongoDbUser()); + assertNull(delegate.getMongoDbPass()); + assertNull(delegate.getMongoDbPassFile()); + assertNull(delegate.getMongoDbAuthSource()); + } + + @Test + void testSettersAndGetters() { + delegate.setMongoDbHost("localhost"); + assertEquals("localhost", delegate.getMongoDbHost()); + + delegate.setMongoDbPort(27017); + assertEquals(27017, delegate.getMongoDbPort()); + + delegate.setMongoDbUser("user"); + assertEquals("user", delegate.getMongoDbUser()); + + delegate.setMongoDbPass("pass"); + assertEquals("pass", delegate.getMongoDbPass()); + + delegate.setMongoDbPassFile("/path/to/pass"); + assertEquals("/path/to/pass", delegate.getMongoDbPassFile()); + + delegate.setMongoDbAuthSource("admin"); + assertEquals("admin", delegate.getMongoDbAuthSource()); + } + + @Test + void testJCommanderParsing() { + String[] args = { + "-mongoDbHost", "mongo.example.com", + "-mongoDbPort", "27018", + "-mongoDbUser", "testuser", + "-mongoDbPass", "testpass", + "-mongoDbPassFile", "/etc/mongodb/pass", + "-mongoDbAuthSource", "testdb" + }; + + JCommander jCommander = JCommander.newBuilder().addObject(delegate).build(); + jCommander.parse(args); + + assertEquals("mongo.example.com", delegate.getMongoDbHost()); + assertEquals(27018, delegate.getMongoDbPort()); + assertEquals("testuser", delegate.getMongoDbUser()); + assertEquals("testpass", delegate.getMongoDbPass()); + assertEquals("/etc/mongodb/pass", delegate.getMongoDbPassFile()); + assertEquals("testdb", delegate.getMongoDbAuthSource()); + } + + @Test + void testJCommanderParsingPartial() { + String[] args = { + "-mongoDbHost", "localhost", + "-mongoDbPort", "27017", + "-mongoDbAuthSource", "admin" + }; + + JCommander jCommander = JCommander.newBuilder().addObject(delegate).build(); + jCommander.parse(args); + + assertEquals("localhost", delegate.getMongoDbHost()); + assertEquals(27017, delegate.getMongoDbPort()); + assertNull(delegate.getMongoDbUser()); + assertNull(delegate.getMongoDbPass()); + assertNull(delegate.getMongoDbPassFile()); + assertEquals("admin", delegate.getMongoDbAuthSource()); + } + + @Test + void testJCommanderParsingEmpty() { + String[] args = {}; + + JCommander jCommander = JCommander.newBuilder().addObject(delegate).build(); + jCommander.parse(args); + + assertNull(delegate.getMongoDbHost()); + assertEquals(0, delegate.getMongoDbPort()); + assertNull(delegate.getMongoDbUser()); + assertNull(delegate.getMongoDbPass()); + assertNull(delegate.getMongoDbPassFile()); + assertNull(delegate.getMongoDbAuthSource()); + } + + @Test + void testPasswordAndPasswordFile() { + // Test that both password and password file can be set + delegate.setMongoDbPass("directpass"); + delegate.setMongoDbPassFile("/path/to/passfile"); + + assertEquals("directpass", delegate.getMongoDbPass()); + assertEquals("/path/to/passfile", delegate.getMongoDbPassFile()); + } + + @Test + void testAuthSourceConfiguration() { + // Test different auth source configurations + delegate.setMongoDbAuthSource("admin"); + assertEquals("admin", delegate.getMongoDbAuthSource()); + + delegate.setMongoDbAuthSource("myappdb"); + assertEquals("myappdb", delegate.getMongoDbAuthSource()); + + delegate.setMongoDbAuthSource("$external"); + assertEquals("$external", delegate.getMongoDbAuthSource()); + } +} diff --git a/src/test/java/de/rub/nds/crawler/config/delegate/RabbitMqDelegateTest.java b/src/test/java/de/rub/nds/crawler/config/delegate/RabbitMqDelegateTest.java new file mode 100644 index 0000000..f479c58 --- /dev/null +++ b/src/test/java/de/rub/nds/crawler/config/delegate/RabbitMqDelegateTest.java @@ -0,0 +1,121 @@ +/* + * TLS-Crawler - A TLS scanning tool to perform large scale scans with the TLS-Scanner + * + * Copyright 2018-2022 Ruhr University Bochum, Paderborn University, and Hackmanit GmbH + * + * Licensed under Apache License, Version 2.0 + * http://www.apache.org/licenses/LICENSE-2.0.txt + */ +package de.rub.nds.crawler.config.delegate; + +import static org.junit.jupiter.api.Assertions.*; + +import com.beust.jcommander.JCommander; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +class RabbitMqDelegateTest { + + private RabbitMqDelegate delegate; + + @BeforeEach + void setUp() { + delegate = new RabbitMqDelegate(); + } + + @Test + void testDefaultValues() { + assertNull(delegate.getRabbitMqHost()); + assertEquals(0, delegate.getRabbitMqPort()); + assertNull(delegate.getRabbitMqUser()); + assertNull(delegate.getRabbitMqPass()); + assertNull(delegate.getRabbitMqPassFile()); + assertFalse(delegate.isRabbitMqTLS()); + } + + @Test + void testSettersAndGetters() { + delegate.setRabbitMqHost("localhost"); + assertEquals("localhost", delegate.getRabbitMqHost()); + + delegate.setRabbitMqPort(5672); + assertEquals(5672, delegate.getRabbitMqPort()); + + delegate.setRabbitMqUser("user"); + assertEquals("user", delegate.getRabbitMqUser()); + + delegate.setRabbitMqPass("pass"); + assertEquals("pass", delegate.getRabbitMqPass()); + + delegate.setRabbitMqPassFile("/path/to/pass"); + assertEquals("/path/to/pass", delegate.getRabbitMqPassFile()); + + delegate.setRabbitMqTLS(true); + assertTrue(delegate.isRabbitMqTLS()); + } + + @Test + void testJCommanderParsing() { + String[] args = { + "-rabbitMqHost", "rabbitmq.example.com", + "-rabbitMqPort", "5673", + "-rabbitMqUser", "testuser", + "-rabbitMqPass", "testpass", + "-rabbitMqPassFile", "/etc/rabbitmq/pass", + "-rabbitMqTLS" + }; + + JCommander jCommander = JCommander.newBuilder().addObject(delegate).build(); + jCommander.parse(args); + + assertEquals("rabbitmq.example.com", delegate.getRabbitMqHost()); + assertEquals(5673, delegate.getRabbitMqPort()); + assertEquals("testuser", delegate.getRabbitMqUser()); + assertEquals("testpass", delegate.getRabbitMqPass()); + assertEquals("/etc/rabbitmq/pass", delegate.getRabbitMqPassFile()); + assertTrue(delegate.isRabbitMqTLS()); + } + + @Test + void testJCommanderParsingPartial() { + String[] args = { + "-rabbitMqHost", "localhost", + "-rabbitMqPort", "5672" + }; + + JCommander jCommander = JCommander.newBuilder().addObject(delegate).build(); + jCommander.parse(args); + + assertEquals("localhost", delegate.getRabbitMqHost()); + assertEquals(5672, delegate.getRabbitMqPort()); + assertNull(delegate.getRabbitMqUser()); + assertNull(delegate.getRabbitMqPass()); + assertNull(delegate.getRabbitMqPassFile()); + assertFalse(delegate.isRabbitMqTLS()); + } + + @Test + void testJCommanderParsingEmpty() { + String[] args = {}; + + JCommander jCommander = JCommander.newBuilder().addObject(delegate).build(); + jCommander.parse(args); + + assertNull(delegate.getRabbitMqHost()); + assertEquals(0, delegate.getRabbitMqPort()); + assertNull(delegate.getRabbitMqUser()); + assertNull(delegate.getRabbitMqPass()); + assertNull(delegate.getRabbitMqPassFile()); + assertFalse(delegate.isRabbitMqTLS()); + } + + @Test + void testPasswordAndPasswordFile() { + // Test that both password and password file can be set + delegate.setRabbitMqPass("directpass"); + delegate.setRabbitMqPassFile("/path/to/passfile"); + + assertEquals("directpass", delegate.getRabbitMqPass()); + assertEquals("/path/to/passfile", delegate.getRabbitMqPassFile()); + } +} diff --git a/src/test/java/de/rub/nds/crawler/constant/JobStatusTest.java b/src/test/java/de/rub/nds/crawler/constant/JobStatusTest.java new file mode 100644 index 0000000..34505cd --- /dev/null +++ b/src/test/java/de/rub/nds/crawler/constant/JobStatusTest.java @@ -0,0 +1,141 @@ +/* + * TLS-Crawler - A TLS scanning tool to perform large scale scans with the TLS-Scanner + * + * Copyright 2018-2023 Ruhr University Bochum, Paderborn University, and Hackmanit GmbH + * + * Licensed under Apache License, Version 2.0 + * http://www.apache.org/licenses/LICENSE-2.0.txt + */ +package de.rub.nds.crawler.constant; + +import static org.junit.jupiter.api.Assertions.*; + +import org.junit.jupiter.api.*; + +class JobStatusTest { + + @Test + void testErrorStatuses() { + // Test all statuses marked as errors + assertTrue(JobStatus.UNRESOLVABLE.isError()); + assertTrue(JobStatus.RESOLUTION_ERROR.isError()); + assertTrue(JobStatus.DENYLISTED.isError()); + assertTrue(JobStatus.ERROR.isError()); + assertTrue(JobStatus.SERIALIZATION_ERROR.isError()); + assertTrue(JobStatus.CANCELLED.isError()); + assertTrue(JobStatus.INTERNAL_ERROR.isError()); + assertTrue(JobStatus.CRAWLER_ERROR.isError()); + } + + @Test + void testNonErrorStatuses() { + // Test all statuses not marked as errors + assertFalse(JobStatus.TO_BE_EXECUTED.isError()); + assertFalse(JobStatus.SUCCESS.isError()); + assertFalse(JobStatus.EMPTY.isError()); + } + + @Test + void testAllStatusesHaveErrorFlag() { + // Ensure all enum values have their error flag properly set + for (JobStatus status : JobStatus.values()) { + // This should not throw - all statuses should have isError defined + boolean isError = status.isError(); + + // Verify the value is either true or false (not null or undefined) + assertTrue(isError || !isError); + } + } + + @Test + void testEnumValueOf() { + // Test that valueOf works for all enum constants + assertEquals(JobStatus.TO_BE_EXECUTED, JobStatus.valueOf("TO_BE_EXECUTED")); + assertEquals(JobStatus.UNRESOLVABLE, JobStatus.valueOf("UNRESOLVABLE")); + assertEquals(JobStatus.RESOLUTION_ERROR, JobStatus.valueOf("RESOLUTION_ERROR")); + assertEquals(JobStatus.DENYLISTED, JobStatus.valueOf("DENYLISTED")); + assertEquals(JobStatus.SUCCESS, JobStatus.valueOf("SUCCESS")); + assertEquals(JobStatus.EMPTY, JobStatus.valueOf("EMPTY")); + assertEquals(JobStatus.ERROR, JobStatus.valueOf("ERROR")); + assertEquals(JobStatus.SERIALIZATION_ERROR, JobStatus.valueOf("SERIALIZATION_ERROR")); + assertEquals(JobStatus.CANCELLED, JobStatus.valueOf("CANCELLED")); + assertEquals(JobStatus.INTERNAL_ERROR, JobStatus.valueOf("INTERNAL_ERROR")); + assertEquals(JobStatus.CRAWLER_ERROR, JobStatus.valueOf("CRAWLER_ERROR")); + } + + @Test + void testEnumValues() { + // Test that values() returns all enum constants + JobStatus[] statuses = JobStatus.values(); + assertEquals(11, statuses.length); + + // Verify all expected values are present + boolean hasToBeExecuted = false; + boolean hasUnresolvable = false; + boolean hasResolutionError = false; + boolean hasDenylisted = false; + boolean hasSuccess = false; + boolean hasEmpty = false; + boolean hasError = false; + boolean hasSerializationError = false; + boolean hasCancelled = false; + boolean hasInternalError = false; + boolean hasCrawlerError = false; + + for (JobStatus status : statuses) { + switch (status) { + case TO_BE_EXECUTED: + hasToBeExecuted = true; + break; + case UNRESOLVABLE: + hasUnresolvable = true; + break; + case RESOLUTION_ERROR: + hasResolutionError = true; + break; + case DENYLISTED: + hasDenylisted = true; + break; + case SUCCESS: + hasSuccess = true; + break; + case EMPTY: + hasEmpty = true; + break; + case ERROR: + hasError = true; + break; + case SERIALIZATION_ERROR: + hasSerializationError = true; + break; + case CANCELLED: + hasCancelled = true; + break; + case INTERNAL_ERROR: + hasInternalError = true; + break; + case CRAWLER_ERROR: + hasCrawlerError = true; + break; + } + } + + assertTrue(hasToBeExecuted); + assertTrue(hasUnresolvable); + assertTrue(hasResolutionError); + assertTrue(hasDenylisted); + assertTrue(hasSuccess); + assertTrue(hasEmpty); + assertTrue(hasError); + assertTrue(hasSerializationError); + assertTrue(hasCancelled); + assertTrue(hasInternalError); + assertTrue(hasCrawlerError); + } + + @Test + void testInvalidValueOf() { + // Test that valueOf throws for invalid values + assertThrows(IllegalArgumentException.class, () -> JobStatus.valueOf("INVALID_STATUS")); + } +} diff --git a/src/test/java/de/rub/nds/crawler/core/BulkScanWorkerManagerTest.java b/src/test/java/de/rub/nds/crawler/core/BulkScanWorkerManagerTest.java new file mode 100644 index 0000000..b73f33b --- /dev/null +++ b/src/test/java/de/rub/nds/crawler/core/BulkScanWorkerManagerTest.java @@ -0,0 +1,246 @@ +/* + * TLS-Crawler - A TLS scanning tool to perform large scale scans with the TLS-Scanner + * + * Copyright 2018-2023 Ruhr University Bochum, Paderborn University, and Hackmanit GmbH + * + * Licensed under Apache License, Version 2.0 + * http://www.apache.org/licenses/LICENSE-2.0.txt + */ +package de.rub.nds.crawler.core; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; + +import de.rub.nds.crawler.data.*; +import java.lang.reflect.Field; +import java.util.concurrent.*; +import org.apache.commons.lang3.exception.UncheckedException; +import org.bson.Document; +import org.junit.jupiter.api.*; +import org.mockito.*; + +class BulkScanWorkerManagerTest { + + private BulkScanWorker mockBulkScanWorker; + @Mock private ScanConfig mockScanConfig; + @Mock private Future mockFuture; + + private BulkScanWorkerManager manager; + + @BeforeEach + void setUp() throws Exception { + MockitoAnnotations.openMocks(this); + mockBulkScanWorker = mock(BulkScanWorker.class); + + // Reset singleton instance + Field instanceField = BulkScanWorkerManager.class.getDeclaredField("instance"); + instanceField.setAccessible(true); + instanceField.set(null, null); + + manager = BulkScanWorkerManager.getInstance(); + } + + @AfterEach + void tearDown() throws Exception { + // Reset singleton instance after each test + Field instanceField = BulkScanWorkerManager.class.getDeclaredField("instance"); + instanceField.setAccessible(true); + instanceField.set(null, null); + } + + @Test + void testGetInstance() { + BulkScanWorkerManager instance1 = BulkScanWorkerManager.getInstance(); + BulkScanWorkerManager instance2 = BulkScanWorkerManager.getInstance(); + assertSame(instance1, instance2); + } + + @Test + void testGetBulkScanWorkerCreatesNewWorker() { + // Given + String bulkScanId = "test-scan-id"; + int parallelConnectionThreads = 10; + int parallelScanThreads = 2; + + // Use doReturn to avoid generic issues + mockBulkScanWorker = mock(BulkScanWorker.class); + doReturn(mockBulkScanWorker) + .when(mockScanConfig) + .createWorker(bulkScanId, parallelConnectionThreads, parallelScanThreads); + + // When + BulkScanWorker worker = + manager.getBulkScanWorker( + bulkScanId, mockScanConfig, parallelConnectionThreads, parallelScanThreads); + + // Then + assertSame(mockBulkScanWorker, worker); + verify(mockBulkScanWorker).init(); + verify(mockScanConfig) + .createWorker(bulkScanId, parallelConnectionThreads, parallelScanThreads); + } + + @Test + void testGetBulkScanWorkerReturnsCachedWorker() { + // Given + String bulkScanId = "test-scan-id"; + int parallelConnectionThreads = 10; + int parallelScanThreads = 2; + + // Use doReturn to avoid generic issues + mockBulkScanWorker = mock(BulkScanWorker.class); + doReturn(mockBulkScanWorker) + .when(mockScanConfig) + .createWorker(bulkScanId, parallelConnectionThreads, parallelScanThreads); + + // When - get worker twice + BulkScanWorker worker1 = + manager.getBulkScanWorker( + bulkScanId, mockScanConfig, parallelConnectionThreads, parallelScanThreads); + BulkScanWorker worker2 = + manager.getBulkScanWorker( + bulkScanId, mockScanConfig, parallelConnectionThreads, parallelScanThreads); + + // Then + assertSame(worker1, worker2); + assertSame(mockBulkScanWorker, worker1); + // Should only create and init once + verify(mockScanConfig, times(1)) + .createWorker(bulkScanId, parallelConnectionThreads, parallelScanThreads); + verify(mockBulkScanWorker, times(1)).init(); + } + + @Test + void testGetBulkScanWorkerThrowsExceptionOnCreationFailure() { + // Given + String bulkScanId = "test-scan-id"; + int parallelConnectionThreads = 10; + int parallelScanThreads = 2; + + when(mockScanConfig.createWorker( + bulkScanId, parallelConnectionThreads, parallelScanThreads)) + .thenThrow(new RuntimeException("Creation failed")); + + // When/Then + assertThrows( + UncheckedException.class, + () -> + manager.getBulkScanWorker( + bulkScanId, + mockScanConfig, + parallelConnectionThreads, + parallelScanThreads)); + } + + @Test + void testHandle() { + // Given + String bulkScanId = "test-scan-id"; + ScanTarget scanTarget = new ScanTarget(); + scanTarget.setHostname("example.com"); + scanTarget.setPort(443); + scanTarget.setIp("192.0.2.1"); + BulkScanInfo bulkScanInfo = mock(BulkScanInfo.class); + ScanJobDescription scanJobDescription = mock(ScanJobDescription.class); + + when(scanJobDescription.getBulkScanInfo()).thenReturn(bulkScanInfo); + when(scanJobDescription.getScanTarget()).thenReturn(scanTarget); + when(bulkScanInfo.getBulkScanId()).thenReturn(bulkScanId); + when(bulkScanInfo.getScanConfig()).thenReturn(mockScanConfig); + mockBulkScanWorker = mock(BulkScanWorker.class); + doReturn(mockBulkScanWorker) + .when(mockScanConfig) + .createWorker(anyString(), anyInt(), anyInt()); + when(mockBulkScanWorker.handle(scanTarget)).thenReturn(mockFuture); + + // When + Future result = manager.handle(scanJobDescription, 10, 2); + + // Then + assertSame(mockFuture, result); + verify(mockBulkScanWorker).handle(scanTarget); + } + + @Test + void testHandleStatic() { + // Given + String bulkScanId = "test-scan-id"; + ScanTarget scanTarget = new ScanTarget(); + scanTarget.setHostname("example.com"); + scanTarget.setPort(443); + scanTarget.setIp("192.0.2.1"); + BulkScanInfo bulkScanInfo = mock(BulkScanInfo.class); + ScanJobDescription scanJobDescription = mock(ScanJobDescription.class); + + when(scanJobDescription.getBulkScanInfo()).thenReturn(bulkScanInfo); + when(scanJobDescription.getScanTarget()).thenReturn(scanTarget); + when(bulkScanInfo.getBulkScanId()).thenReturn(bulkScanId); + when(bulkScanInfo.getScanConfig()).thenReturn(mockScanConfig); + mockBulkScanWorker = mock(BulkScanWorker.class); + doReturn(mockBulkScanWorker) + .when(mockScanConfig) + .createWorker(anyString(), anyInt(), anyInt()); + when(mockBulkScanWorker.handle(scanTarget)).thenReturn(mockFuture); + + // When + Future result = BulkScanWorkerManager.handleStatic(scanJobDescription, 10, 2); + + // Then + assertSame(mockFuture, result); + verify(mockBulkScanWorker).handle(scanTarget); + } + + @Test + void testCacheEvictionCallsCleanup() throws Exception { + // Given + String bulkScanId = "test-scan-id"; + + mockBulkScanWorker = mock(BulkScanWorker.class); + doReturn(mockBulkScanWorker) + .when(mockScanConfig) + .createWorker(anyString(), anyInt(), anyInt()); + + // Get the cache field via reflection + Field cacheField = BulkScanWorkerManager.class.getDeclaredField("bulkScanWorkers"); + cacheField.setAccessible(true); + @SuppressWarnings("unchecked") + com.google.common.cache.Cache> cache = + (com.google.common.cache.Cache>) cacheField.get(manager); + + // When - add worker and then invalidate + manager.getBulkScanWorker(bulkScanId, mockScanConfig, 10, 2); + cache.invalidate(bulkScanId); + + // Give some time for async removal listener + Thread.sleep(100); + + // Then + verify(mockBulkScanWorker).cleanup(); + } + + @Test + void testMultipleBulkScansHaveSeparateWorkers() { + // Given + String bulkScanId1 = "scan-1"; + String bulkScanId2 = "scan-2"; + + // Create separate mocks for each worker + BulkScanWorker worker1 = mock(BulkScanWorker.class); + BulkScanWorker worker2 = mock(BulkScanWorker.class); + + doReturn(worker1).when(mockScanConfig).createWorker(eq(bulkScanId1), anyInt(), anyInt()); + doReturn(worker2).when(mockScanConfig).createWorker(eq(bulkScanId2), anyInt(), anyInt()); + + // When + BulkScanWorker result1 = manager.getBulkScanWorker(bulkScanId1, mockScanConfig, 10, 2); + BulkScanWorker result2 = manager.getBulkScanWorker(bulkScanId2, mockScanConfig, 10, 2); + + // Then + assertNotSame(result1, result2); + assertSame(worker1, result1); + assertSame(worker2, result2); + verify(worker1).init(); + verify(worker2).init(); + } +} diff --git a/src/test/java/de/rub/nds/crawler/core/ControllerTest.java b/src/test/java/de/rub/nds/crawler/core/ControllerTest.java index afddf0f..48a5774 100644 --- a/src/test/java/de/rub/nds/crawler/core/ControllerTest.java +++ b/src/test/java/de/rub/nds/crawler/core/ControllerTest.java @@ -8,18 +8,46 @@ */ package de.rub.nds.crawler.core; +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; + import de.rub.nds.crawler.config.ControllerCommandConfig; +import de.rub.nds.crawler.denylist.IDenylistProvider; import de.rub.nds.crawler.dummy.DummyControllerCommandConfig; import de.rub.nds.crawler.dummy.DummyOrchestrationProvider; import de.rub.nds.crawler.dummy.DummyPersistenceProvider; +import de.rub.nds.crawler.orchestration.IOrchestrationProvider; +import de.rub.nds.crawler.persistence.IPersistenceProvider; +import de.rub.nds.crawler.targetlist.ITargetListProvider; import java.io.File; import java.io.FileWriter; import java.io.IOException; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; +import java.lang.reflect.Field; +import java.util.HashSet; +import java.util.Set; +import org.junit.jupiter.api.*; +import org.mockito.*; +import org.quartz.*; +import org.quartz.impl.matchers.GroupMatcher; class ControllerTest { + @Mock private ControllerCommandConfig mockConfig; + @Mock private IOrchestrationProvider mockOrchestrationProvider; + @Mock private IPersistenceProvider mockPersistenceProvider; + @Mock private ITargetListProvider mockTargetListProvider; + @Mock private Scheduler mockScheduler; + @Mock private ListenerManager mockListenerManager; + @Mock private Trigger mockTrigger; + + private Controller controller; + + @BeforeEach + void setUp() { + MockitoAnnotations.openMocks(this); + } + @Test void submitting() throws IOException, InterruptedException { var persistenceProvider = new DummyPersistenceProvider(); @@ -43,4 +71,163 @@ void submitting() throws IOException, InterruptedException { Assertions.assertEquals(2, orchestrationProvider.jobQueue.size()); Assertions.assertEquals(0, orchestrationProvider.unackedJobs.size()); } + + @Test + void testConstructorWithoutDenylist() { + when(mockConfig.getDenylistFile()).thenReturn(null); + + controller = new Controller(mockConfig, mockOrchestrationProvider, mockPersistenceProvider); + + assertNotNull(controller); + verify(mockConfig).getDenylistFile(); + } + + @Test + void testConstructorWithDenylist() throws Exception { + File denylistFile = File.createTempFile("denylist", ".txt"); + denylistFile.deleteOnExit(); + + when(mockConfig.getDenylistFile()).thenReturn(denylistFile.getAbsolutePath()); + + controller = new Controller(mockConfig, mockOrchestrationProvider, mockPersistenceProvider); + + // Verify denylist provider was created via reflection + Field denylistField = Controller.class.getDeclaredField("denylistProvider"); + denylistField.setAccessible(true); + IDenylistProvider denylistProvider = (IDenylistProvider) denylistField.get(controller); + + assertNotNull(denylistProvider); + } + + @Test + void testStartWithCronSchedule() throws SchedulerException { + // Setup + when(mockConfig.getTargetListProvider()).thenReturn(mockTargetListProvider); + when(mockConfig.getScanCronInterval()).thenReturn("0 0 12 * * ?"); + when(mockConfig.isMonitored()).thenReturn(false); + when(mockConfig.getDenylistFile()).thenReturn(null); + + // Use spy to intercept scheduler creation + Controller controllerSpy = + spy(new Controller(mockConfig, mockOrchestrationProvider, mockPersistenceProvider)); + + // We can't easily mock the StdSchedulerFactory, so let's test what we can + assertDoesNotThrow(() -> controllerSpy.start()); + + verify(mockConfig).getTargetListProvider(); + verify(mockConfig, atLeastOnce()).getScanCronInterval(); + verify(mockConfig).isMonitored(); + } + + @Test + void testStartWithSimpleSchedule() { + // Setup + when(mockConfig.getTargetListProvider()).thenReturn(mockTargetListProvider); + when(mockConfig.getScanCronInterval()).thenReturn(null); + when(mockConfig.isMonitored()).thenReturn(false); + when(mockConfig.getDenylistFile()).thenReturn(null); + + Controller controller = + new Controller(mockConfig, mockOrchestrationProvider, mockPersistenceProvider); + + assertDoesNotThrow(() -> controller.start()); + + verify(mockConfig).getScanCronInterval(); + } + + @Test + void testStartWithMonitoring() { + // Setup + when(mockConfig.getTargetListProvider()).thenReturn(mockTargetListProvider); + when(mockConfig.getScanCronInterval()).thenReturn(null); + when(mockConfig.isMonitored()).thenReturn(true); + when(mockConfig.getDenylistFile()).thenReturn(null); + + Controller controller = + new Controller(mockConfig, mockOrchestrationProvider, mockPersistenceProvider); + + assertDoesNotThrow(() -> controller.start()); + + verify(mockConfig).isMonitored(); + } + + @Test + void testShutdownSchedulerIfAllTriggersFinalized_NoTriggers() throws SchedulerException { + // Setup + Set triggerKeys = new HashSet<>(); + when(mockScheduler.getTriggerKeys(any(GroupMatcher.class))).thenReturn(triggerKeys); + + // When + Controller.shutdownSchedulerIfAllTriggersFinalized(mockScheduler); + + // Then + verify(mockScheduler).shutdown(); + } + + @Test + void testShutdownSchedulerIfAllTriggersFinalized_WithActiveTrigger() throws SchedulerException { + // Setup + TriggerKey triggerKey = new TriggerKey("trigger1", "group1"); + Set triggerKeys = new HashSet<>(); + triggerKeys.add(triggerKey); + + when(mockScheduler.getTriggerKeys(any(GroupMatcher.class))).thenReturn(triggerKeys); + when(mockScheduler.getTrigger(triggerKey)).thenReturn(mockTrigger); + when(mockTrigger.mayFireAgain()).thenReturn(true); + + // When + Controller.shutdownSchedulerIfAllTriggersFinalized(mockScheduler); + + // Then + verify(mockScheduler, never()).shutdown(); + } + + @Test + void testShutdownSchedulerIfAllTriggersFinalized_WithInactiveTrigger() + throws SchedulerException { + // Setup + TriggerKey triggerKey = new TriggerKey("trigger1", "group1"); + Set triggerKeys = new HashSet<>(); + triggerKeys.add(triggerKey); + + when(mockScheduler.getTriggerKeys(any(GroupMatcher.class))).thenReturn(triggerKeys); + when(mockScheduler.getTrigger(triggerKey)).thenReturn(mockTrigger); + when(mockTrigger.mayFireAgain()).thenReturn(false); + + // When + Controller.shutdownSchedulerIfAllTriggersFinalized(mockScheduler); + + // Then + verify(mockScheduler).shutdown(); + } + + @Test + void testShutdownSchedulerIfAllTriggersFinalized_SchedulerException() + throws SchedulerException { + // Setup + when(mockScheduler.getTriggerKeys(any(GroupMatcher.class))) + .thenThrow(new SchedulerException("Test exception")); + + // When/Then - should not throw + assertDoesNotThrow(() -> Controller.shutdownSchedulerIfAllTriggersFinalized(mockScheduler)); + } + + @Test + void testShutdownSchedulerIfAllTriggersFinalized_TriggerReadException() + throws SchedulerException { + // Setup + TriggerKey triggerKey = new TriggerKey("trigger1", "group1"); + Set triggerKeys = new HashSet<>(); + triggerKeys.add(triggerKey); + + when(mockScheduler.getTriggerKeys(any(GroupMatcher.class))).thenReturn(triggerKeys); + when(mockScheduler.getTrigger(triggerKey)) + .thenThrow(new SchedulerException("Cannot read trigger")); + + // When + Controller.shutdownSchedulerIfAllTriggersFinalized(mockScheduler); + + // Then - should not shutdown due to exception (treated as still running) + verify(mockScheduler, never()).shutdown(); + } } diff --git a/src/test/java/de/rub/nds/crawler/core/ProgressMonitorTest.java b/src/test/java/de/rub/nds/crawler/core/ProgressMonitorTest.java new file mode 100644 index 0000000..d6884c5 --- /dev/null +++ b/src/test/java/de/rub/nds/crawler/core/ProgressMonitorTest.java @@ -0,0 +1,446 @@ +/* + * TLS-Crawler - A TLS scanning tool to perform large scale scans with the TLS-Scanner + * + * Copyright 2018-2022 Ruhr University Bochum, Paderborn University, and Hackmanit GmbH + * + * Licensed under Apache License, Version 2.0 + * http://www.apache.org/licenses/LICENSE-2.0.txt + */ +package de.rub.nds.crawler.core; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; + +import de.rub.nds.crawler.constant.JobStatus; +import de.rub.nds.crawler.data.BulkScan; +import de.rub.nds.crawler.data.ScanConfig; +import de.rub.nds.crawler.data.ScanJobDescription; +import de.rub.nds.crawler.data.ScanTarget; +import de.rub.nds.crawler.orchestration.DoneNotificationConsumer; +import de.rub.nds.crawler.orchestration.IOrchestrationProvider; +import de.rub.nds.crawler.persistence.IPersistenceProvider; +import de.rub.nds.scanner.core.config.ScannerDetail; +import java.io.IOException; +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.net.http.HttpClient; +import java.net.http.HttpRequest; +import java.net.http.HttpResponse; +import java.util.Map; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockedStatic; +import org.mockito.MockitoAnnotations; +import org.quartz.Scheduler; +import org.quartz.SchedulerException; + +class ProgressMonitorTest { + + @Mock private IOrchestrationProvider orchestrationProvider; + + @Mock private IPersistenceProvider persistenceProvider; + + @Mock private Scheduler scheduler; + + @Mock private HttpClient httpClient; + + @Mock private HttpResponse httpResponse; + + private ProgressMonitor progressMonitor; + + @BeforeEach + void setUp() { + MockitoAnnotations.openMocks(this); + progressMonitor = + new ProgressMonitor(orchestrationProvider, persistenceProvider, scheduler); + } + + private BulkScan createTestBulkScan(String id, String name) { + ScanConfig scanConfig = + new ScanConfig(ScannerDetail.NORMAL, 3, 2000) { + @Override + public BulkScanWorker createWorker( + String bulkScanID, + int parallelConnectionThreads, + int parallelScanThreads) { + return null; + } + }; + BulkScan bulkScan = + new BulkScan( + ProgressMonitorTest.class, + ProgressMonitorTest.class, + name, + scanConfig, + System.currentTimeMillis(), + true, + null); + bulkScan.set_id(id); + return bulkScan; + } + + private ScanJobDescription createTestScanJob(BulkScan bulkScan, JobStatus status) { + ScanTarget target = new ScanTarget(); + target.setHostname("example.com"); + target.setIp("192.0.2.1"); + return new ScanJobDescription(target, bulkScan, status); + } + + @Test + void testStartMonitoringBulkScanProgress() throws Exception { + BulkScan bulkScan = createTestBulkScan("test-id", "test-scan"); + + progressMonitor.startMonitoringBulkScanProgress(bulkScan); + + // Verify that done notification consumer is registered + verify(orchestrationProvider) + .registerDoneNotificationConsumer( + eq(bulkScan), any(DoneNotificationConsumer.class)); + + // Check that bulk scan is tracked internally + Field scanJobDetailsByIdField = + ProgressMonitor.class.getDeclaredField("scanJobDetailsById"); + scanJobDetailsByIdField.setAccessible(true); + Map scanJobDetailsById = + (Map) scanJobDetailsByIdField.get(progressMonitor); + assertTrue(scanJobDetailsById.containsKey("test-id")); + } + + @Test + void testStartMonitoringMultipleBulkScans() throws Exception { + BulkScan bulkScan1 = createTestBulkScan("test-id-1", "test-scan-1"); + BulkScan bulkScan2 = createTestBulkScan("test-id-2", "test-scan-2"); + + progressMonitor.startMonitoringBulkScanProgress(bulkScan1); + progressMonitor.startMonitoringBulkScanProgress(bulkScan2); + + // Should only register listener once + verify(orchestrationProvider, times(1)) + .registerDoneNotificationConsumer(any(), any(DoneNotificationConsumer.class)); + + // Check that both bulk scans are tracked + Field scanJobDetailsByIdField = + ProgressMonitor.class.getDeclaredField("scanJobDetailsById"); + scanJobDetailsByIdField.setAccessible(true); + Map scanJobDetailsById = + (Map) scanJobDetailsByIdField.get(progressMonitor); + assertEquals(2, scanJobDetailsById.size()); + assertTrue(scanJobDetailsById.containsKey("test-id-1")); + assertTrue(scanJobDetailsById.containsKey("test-id-2")); + } + + @Test + void testStopMonitoringAndFinalizeBulkScanWithoutNotification() throws Exception { + // Setup + BulkScan bulkScan = createTestBulkScan("test-id", "test-scan"); + + progressMonitor.startMonitoringBulkScanProgress(bulkScan); + + // Execute + progressMonitor.stopMonitoringAndFinalizeBulkScan("test-id"); + + // Verify + verify(persistenceProvider) + .updateBulkScan( + argThat( + scan -> { + return scan.isFinished() && scan.getEndTime() > 0; + })); + + // Check that bulk scan is removed from tracking + Field scanJobDetailsByIdField = + ProgressMonitor.class.getDeclaredField("scanJobDetailsById"); + scanJobDetailsByIdField.setAccessible(true); + Map scanJobDetailsById = + (Map) scanJobDetailsByIdField.get(progressMonitor); + assertFalse(scanJobDetailsById.containsKey("test-id")); + } + + @Test + void testStopMonitoringAndFinalizeBulkScanWithNotification() throws Exception { + // Setup + BulkScan bulkScan = createTestBulkScan("test-id", "test-scan"); + bulkScan.setNotifyUrl("http://example.com/notify"); + + progressMonitor.startMonitoringBulkScanProgress(bulkScan); + + try (MockedStatic mockedHttpClient = mockStatic(HttpClient.class)) { + mockedHttpClient.when(HttpClient::newHttpClient).thenReturn(httpClient); + when(httpClient.send(any(HttpRequest.class), any(HttpResponse.BodyHandler.class))) + .thenReturn(httpResponse); + when(httpResponse.body()).thenReturn("OK"); + + // Execute + progressMonitor.stopMonitoringAndFinalizeBulkScan("test-id"); + + // Verify notification was sent + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(HttpRequest.class); + verify(httpClient).send(requestCaptor.capture(), any()); + + HttpRequest request = requestCaptor.getValue(); + assertEquals("http://example.com/notify", request.uri().toString()); + assertEquals("POST", request.method()); + } + + verify(persistenceProvider).updateBulkScan(any()); + } + + @Test + void testStopMonitoringWithSchedulerShutdown() throws Exception { + // Setup + BulkScan bulkScan = createTestBulkScan("test-id", "test-scan"); + + progressMonitor.startMonitoringBulkScanProgress(bulkScan); + when(scheduler.isShutdown()).thenReturn(true); + + // Execute + progressMonitor.stopMonitoringAndFinalizeBulkScan("test-id"); + + // Verify + verify(orchestrationProvider).closeConnection(); + } + + @Test + void testStopMonitoringWithSchedulerException() throws Exception { + // Setup + BulkScan bulkScan = createTestBulkScan("test-id", "test-scan"); + + progressMonitor.startMonitoringBulkScanProgress(bulkScan); + when(scheduler.isShutdown()).thenThrow(new SchedulerException("Test exception")); + + // Execute - should not throw + assertDoesNotThrow(() -> progressMonitor.stopMonitoringAndFinalizeBulkScan("test-id")); + } + + @Test + void testNotifyIOException() throws Exception { + // Setup + BulkScan bulkScan = createTestBulkScan("test-id", "test-scan"); + bulkScan.setNotifyUrl("http://example.com/notify"); + + progressMonitor.startMonitoringBulkScanProgress(bulkScan); + + try (MockedStatic mockedHttpClient = mockStatic(HttpClient.class)) { + mockedHttpClient.when(HttpClient::newHttpClient).thenReturn(httpClient); + when(httpClient.send(any(HttpRequest.class), any())) + .thenThrow(new IOException("Network error")); + + // Execute - should not throw + assertDoesNotThrow(() -> progressMonitor.stopMonitoringAndFinalizeBulkScan("test-id")); + } + + verify(persistenceProvider).updateBulkScan(any()); + } + + @Test + void testNotifyInterruptedException() throws Exception { + // Setup + BulkScan bulkScan = createTestBulkScan("test-id", "test-scan"); + bulkScan.setNotifyUrl("http://example.com/notify"); + + progressMonitor.startMonitoringBulkScanProgress(bulkScan); + + try (MockedStatic mockedHttpClient = mockStatic(HttpClient.class)) { + mockedHttpClient.when(HttpClient::newHttpClient).thenReturn(httpClient); + when(httpClient.send(any(HttpRequest.class), any())) + .thenThrow(new InterruptedException("Interrupted")); + + // Execute - should not throw + assertDoesNotThrow(() -> progressMonitor.stopMonitoringAndFinalizeBulkScan("test-id")); + + // Verify thread interrupt flag is set + assertTrue(Thread.currentThread().isInterrupted()); + // Clear interrupt flag + Thread.interrupted(); + } + + verify(persistenceProvider).updateBulkScan(any()); + } + + @Test + void testBulkScanMonitorConsumeDoneNotification() throws Exception { + // Setup + BulkScan bulkScan = createTestBulkScan("test-id", "test-scan"); + bulkScan.setStartTime(System.currentTimeMillis()); + bulkScan.setScanJobsPublished(10); + + progressMonitor.startMonitoringBulkScanProgress(bulkScan); + + // Get the registered consumer + ArgumentCaptor consumerCaptor = + ArgumentCaptor.forClass(DoneNotificationConsumer.class); + verify(orchestrationProvider) + .registerDoneNotificationConsumer(eq(bulkScan), consumerCaptor.capture()); + DoneNotificationConsumer consumer = consumerCaptor.getValue(); + + // Create scan job + ScanJobDescription scanJob = createTestScanJob(bulkScan, JobStatus.SUCCESS); + + // Consume notification + consumer.consumeDoneNotification("test-tag", scanJob); + + // Should not finalize yet (1 of 10 done) + verify(persistenceProvider, never()).updateBulkScan(any()); + } + + @Test + void testBulkScanMonitorCompleteAllJobs() throws Exception { + // Setup + BulkScan bulkScan = createTestBulkScan("test-id", "test-scan"); + bulkScan.setStartTime(System.currentTimeMillis()); + bulkScan.setScanJobsPublished(2); + + progressMonitor.startMonitoringBulkScanProgress(bulkScan); + + // Get the registered consumer + ArgumentCaptor consumerCaptor = + ArgumentCaptor.forClass(DoneNotificationConsumer.class); + verify(orchestrationProvider) + .registerDoneNotificationConsumer(eq(bulkScan), consumerCaptor.capture()); + DoneNotificationConsumer consumer = consumerCaptor.getValue(); + + // Create and consume first job + ScanJobDescription scanJob1 = createTestScanJob(bulkScan, JobStatus.SUCCESS); + consumer.consumeDoneNotification("test-tag", scanJob1); + + // Create and consume second job (should trigger completion) + ScanJobDescription scanJob2 = createTestScanJob(bulkScan, JobStatus.CANCELLED); + consumer.consumeDoneNotification("test-tag", scanJob2); + + // Should finalize bulk scan + verify(persistenceProvider) + .updateBulkScan( + argThat( + scan -> { + return scan.isFinished() && scan.getSuccessfulScans() == 1; + })); + } + + @Test + void testBulkScanMonitorWithException() throws Exception { + // Setup + BulkScan bulkScan = createTestBulkScan("test-id", "test-scan"); + bulkScan.setStartTime(System.currentTimeMillis()); + bulkScan.setScanJobsPublished(1); + + progressMonitor.startMonitoringBulkScanProgress(bulkScan); + + // Get the registered consumer + ArgumentCaptor consumerCaptor = + ArgumentCaptor.forClass(DoneNotificationConsumer.class); + verify(orchestrationProvider) + .registerDoneNotificationConsumer(eq(bulkScan), consumerCaptor.capture()); + DoneNotificationConsumer consumer = consumerCaptor.getValue(); + + // Create scan job with null bulk scan info to trigger exception + ScanTarget target = new ScanTarget(); + target.setHostname("example.com"); + ScanJobDescription scanJob = + new ScanJobDescription( + target, null, "test-db", "test-collection", JobStatus.SUCCESS); + + // Should not throw + assertDoesNotThrow(() -> consumer.consumeDoneNotification("test-tag", scanJob)); + } + + @Test + void testFormatTime() throws Exception { + // Get access to private formatTime method + Method formatTimeMethod = null; + for (Class innerClass : ProgressMonitor.class.getDeclaredClasses()) { + if (innerClass.getSimpleName().equals("BulkscanMonitor")) { + formatTimeMethod = innerClass.getDeclaredMethod("formatTime", double.class); + formatTimeMethod.setAccessible(true); + break; + } + } + assertNotNull(formatTimeMethod); + + // Create instance of inner class + BulkScan bulkScan = createTestBulkScan("test-id", "test-scan"); + progressMonitor.startMonitoringBulkScanProgress(bulkScan); + + ArgumentCaptor consumerCaptor = + ArgumentCaptor.forClass(DoneNotificationConsumer.class); + verify(orchestrationProvider) + .registerDoneNotificationConsumer(eq(bulkScan), consumerCaptor.capture()); + Object bulkScanMonitor = consumerCaptor.getValue(); + + // Test different time formats + assertEquals(" 500 ms", formatTimeMethod.invoke(bulkScanMonitor, 500.0)); + assertEquals(" 5.00 s", formatTimeMethod.invoke(bulkScanMonitor, 5000.0)); + assertEquals("90.00 s", formatTimeMethod.invoke(bulkScanMonitor, 90000.0)); + assertEquals(" 3 h 30 m", formatTimeMethod.invoke(bulkScanMonitor, 9000000.0)); + assertEquals("2.1 d", formatTimeMethod.invoke(bulkScanMonitor, 180000000.0)); + } + + @Test + void testBulkScanMonitorUsesTargetsGivenWhenNoJobsPublished() throws Exception { + // Setup + BulkScan bulkScan = createTestBulkScan("test-id", "test-scan"); + bulkScan.setStartTime(System.currentTimeMillis()); + bulkScan.setScanJobsPublished(0); // No jobs published + bulkScan.setTargetsGiven(5); // But targets given + + progressMonitor.startMonitoringBulkScanProgress(bulkScan); + + // Get the registered consumer + ArgumentCaptor consumerCaptor = + ArgumentCaptor.forClass(DoneNotificationConsumer.class); + verify(orchestrationProvider) + .registerDoneNotificationConsumer(eq(bulkScan), consumerCaptor.capture()); + DoneNotificationConsumer consumer = consumerCaptor.getValue(); + + // Consume notifications for all 5 targets + for (int i = 0; i < 5; i++) { + ScanJobDescription scanJob = createTestScanJob(bulkScan, JobStatus.SUCCESS); + consumer.consumeDoneNotification("test-tag", scanJob); + } + + // Should finalize after 5 jobs (using targetsGiven) + verify(persistenceProvider).updateBulkScan(any()); + } + + @Test + void testDifferentJobStatuses() throws Exception { + // Setup + BulkScan bulkScan = createTestBulkScan("test-id", "test-scan"); + bulkScan.setStartTime(System.currentTimeMillis()); + bulkScan.setScanJobsPublished(6); + + progressMonitor.startMonitoringBulkScanProgress(bulkScan); + + // Get the registered consumer + ArgumentCaptor consumerCaptor = + ArgumentCaptor.forClass(DoneNotificationConsumer.class); + verify(orchestrationProvider) + .registerDoneNotificationConsumer(eq(bulkScan), consumerCaptor.capture()); + DoneNotificationConsumer consumer = consumerCaptor.getValue(); + + // Test different job statuses + JobStatus[] statuses = { + JobStatus.SUCCESS, + JobStatus.EMPTY, + JobStatus.CANCELLED, + JobStatus.ERROR, + JobStatus.SERIALIZATION_ERROR, + JobStatus.INTERNAL_ERROR + }; + + for (JobStatus status : statuses) { + ScanJobDescription scanJob = createTestScanJob(bulkScan, status); + consumer.consumeDoneNotification("test-tag", scanJob); + } + + // Verify the bulk scan has correct counters + ArgumentCaptor bulkScanCaptor = ArgumentCaptor.forClass(BulkScan.class); + verify(persistenceProvider).updateBulkScan(bulkScanCaptor.capture()); + + BulkScan finalizedBulkScan = bulkScanCaptor.getValue(); + assertEquals(1, finalizedBulkScan.getSuccessfulScans()); + assertNotNull(finalizedBulkScan.getJobStatusCounters()); + } +} diff --git a/src/test/java/de/rub/nds/crawler/core/SchedulerListenerShutdownTest.java b/src/test/java/de/rub/nds/crawler/core/SchedulerListenerShutdownTest.java new file mode 100644 index 0000000..4eb80e3 --- /dev/null +++ b/src/test/java/de/rub/nds/crawler/core/SchedulerListenerShutdownTest.java @@ -0,0 +1,118 @@ +/* + * TLS-Crawler - A TLS scanning tool to perform large scale scans with the TLS-Scanner + * + * Copyright 2018-2023 Ruhr University Bochum, Paderborn University, and Hackmanit GmbH + * + * Licensed under Apache License, Version 2.0 + * http://www.apache.org/licenses/LICENSE-2.0.txt + */ +package de.rub.nds.crawler.core; + +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; + +import java.util.HashSet; +import java.util.Set; +import org.junit.jupiter.api.*; +import org.mockito.*; +import org.quartz.*; +import org.quartz.impl.matchers.GroupMatcher; + +class SchedulerListenerShutdownTest { + + @Mock private Scheduler mockScheduler; + @Mock private Trigger mockTrigger; + @Mock private TriggerKey mockTriggerKey; + @Mock private JobDetail mockJobDetail; + @Mock private JobKey mockJobKey; + + private SchedulerListenerShutdown listener; + + @BeforeEach + void setUp() throws SchedulerException { + MockitoAnnotations.openMocks(this); + listener = new SchedulerListenerShutdown(mockScheduler); + + // Default behavior - no triggers + Set emptySet = new HashSet<>(); + when(mockScheduler.getTriggerKeys(any(GroupMatcher.class))).thenReturn(emptySet); + } + + @Test + void testJobScheduled() throws SchedulerException { + // When + listener.jobScheduled(mockTrigger); + + // Then - should check if scheduler should shutdown + verify(mockScheduler).getTriggerKeys(any(GroupMatcher.class)); + } + + @Test + void testJobUnscheduled() throws SchedulerException { + // When + listener.jobUnscheduled(mockTriggerKey); + + // Then - should check if scheduler should shutdown + verify(mockScheduler).getTriggerKeys(any(GroupMatcher.class)); + } + + @Test + void testTriggerFinalized() throws SchedulerException { + // When + listener.triggerFinalized(mockTrigger); + + // Then - should check if scheduler should shutdown + verify(mockScheduler).getTriggerKeys(any(GroupMatcher.class)); + } + + @Test + void testTriggerFinalizedCausesShutdown() throws SchedulerException { + // Given - no active triggers + Set emptySet = new HashSet<>(); + when(mockScheduler.getTriggerKeys(any(GroupMatcher.class))).thenReturn(emptySet); + + // When + listener.triggerFinalized(mockTrigger); + + // Then - should shutdown + verify(mockScheduler).shutdown(); + } + + @Test + void testEmptyMethodsDontTriggerShutdown() throws SchedulerException { + // Test all the empty methods + listener.triggerPaused(mockTriggerKey); + listener.triggersPaused("group"); + listener.triggerResumed(mockTriggerKey); + listener.triggersResumed("group"); + listener.jobAdded(mockJobDetail); + listener.jobDeleted(mockJobKey); + listener.jobPaused(mockJobKey); + listener.jobsPaused("group"); + listener.jobResumed(mockJobKey); + listener.jobsResumed("group"); + listener.schedulerError("error", new SchedulerException()); + listener.schedulerInStandbyMode(); + listener.schedulerStarted(); + listener.schedulerStarting(); + listener.schedulerShutdown(); + listener.schedulerShuttingdown(); + listener.schedulingDataCleared(); + + // Then - none of these should trigger shutdown check + verify(mockScheduler, never()).getTriggerKeys(any(GroupMatcher.class)); + verify(mockScheduler, never()).shutdown(); + } + + @Test + void testConstructor() throws SchedulerException { + // Test that constructor properly stores scheduler reference + SchedulerListenerShutdown newListener = new SchedulerListenerShutdown(mockScheduler); + + // When - trigger an event that uses the scheduler + newListener.jobScheduled(mockTrigger); + + // Then - verify it uses the provided scheduler + verify(mockScheduler).getTriggerKeys(any(GroupMatcher.class)); + } +} diff --git a/src/test/java/de/rub/nds/crawler/core/WorkerTest.java b/src/test/java/de/rub/nds/crawler/core/WorkerTest.java new file mode 100644 index 0000000..77d5c3f --- /dev/null +++ b/src/test/java/de/rub/nds/crawler/core/WorkerTest.java @@ -0,0 +1,468 @@ +/* + * TLS-Crawler - A TLS scanning tool to perform large scale scans with the TLS-Scanner + * + * Copyright 2018-2023 Ruhr University Bochum, Paderborn University, and Hackmanit GmbH + * + * Licensed under Apache License, Version 2.0 + * http://www.apache.org/licenses/LICENSE-2.0.txt + */ +package de.rub.nds.crawler.core; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; + +import de.rub.nds.crawler.config.WorkerCommandConfig; +import de.rub.nds.crawler.constant.JobStatus; +import de.rub.nds.crawler.data.*; +import de.rub.nds.crawler.orchestration.IOrchestrationProvider; +import de.rub.nds.crawler.orchestration.ScanJobConsumer; +import de.rub.nds.crawler.persistence.IPersistenceProvider; +import java.util.concurrent.*; +import org.bson.Document; +import org.junit.jupiter.api.*; +import org.mockito.*; + +class WorkerTest { + + @Mock private WorkerCommandConfig commandConfig; + @Mock private IOrchestrationProvider orchestrationProvider; + @Mock private IPersistenceProvider persistenceProvider; + @Mock private Future resultFuture; + + private Worker worker; + private ArgumentCaptor consumerCaptor; + private ArgumentCaptor scanResultCaptor; + private ArgumentCaptor jobDescriptionCaptor; + + @BeforeEach + void setUp() { + MockitoAnnotations.openMocks(this); + + // Set up command config + when(commandConfig.getParallelScanThreads()).thenReturn(2); + when(commandConfig.getParallelConnectionThreads()).thenReturn(10); + when(commandConfig.getScanTimeout()).thenReturn(60000); // 60 seconds + + // Initialize captors + consumerCaptor = ArgumentCaptor.forClass(ScanJobConsumer.class); + scanResultCaptor = ArgumentCaptor.forClass(ScanResult.class); + jobDescriptionCaptor = ArgumentCaptor.forClass(ScanJobDescription.class); + + worker = new Worker(commandConfig, orchestrationProvider, persistenceProvider); + } + + @AfterEach + void tearDown() { + // Clean up any resources if needed + } + + @Test + void testConstructor() { + assertNotNull(worker); + // Verify that config values were read + verify(commandConfig).getParallelScanThreads(); + verify(commandConfig).getParallelConnectionThreads(); + verify(commandConfig).getScanTimeout(); + } + + @Test + void testStart() { + // When + worker.start(); + + // Then + verify(orchestrationProvider).registerScanJobConsumer(consumerCaptor.capture(), eq(2)); + assertNotNull(consumerCaptor.getValue()); + } + + @Test + void testHandleScanJobSuccess() throws Exception { + // Given + ScanTarget scanTarget = new ScanTarget(); + scanTarget.setHostname("example.com"); + scanTarget.setPort(443); + scanTarget.setIp("192.0.2.1"); + BulkScan bulkScan = createMockBulkScan(); + ScanJobDescription jobDescription = + new ScanJobDescription(scanTarget, bulkScan, JobStatus.TO_BE_EXECUTED); + + Document resultDocument = new Document("result", "success"); + + // Mock the static method call to BulkScanWorkerManager + try (MockedStatic mockedStatic = + mockStatic(BulkScanWorkerManager.class)) { + mockedStatic + .when( + () -> + BulkScanWorkerManager.handleStatic( + any(ScanJobDescription.class), anyInt(), anyInt())) + .thenReturn(resultFuture); + + when(resultFuture.get(anyLong(), any(TimeUnit.class))).thenReturn(resultDocument); + + // Start worker and capture the consumer + worker.start(); + verify(orchestrationProvider) + .registerScanJobConsumer(consumerCaptor.capture(), anyInt()); + ScanJobConsumer consumer = consumerCaptor.getValue(); + + // When - simulate receiving a scan job + consumer.consumeScanJob(jobDescription); + + // Wait for async processing + Thread.sleep(100); + + // Then + verify(persistenceProvider) + .insertScanResult(scanResultCaptor.capture(), jobDescriptionCaptor.capture()); + + ScanResult capturedResult = scanResultCaptor.getValue(); + assertEquals(JobStatus.SUCCESS, capturedResult.getResultStatus()); + assertEquals(resultDocument, capturedResult.getResult()); + + verify(orchestrationProvider).notifyOfDoneScanJob(jobDescription); + } + } + + @Test + void testHandleScanJobTimeout() throws Exception { + // Given + ScanTarget scanTarget = new ScanTarget(); + scanTarget.setHostname("example.com"); + scanTarget.setPort(443); + scanTarget.setIp("192.0.2.1"); + BulkScan bulkScan = createMockBulkScan(); + ScanJobDescription jobDescription = + new ScanJobDescription(scanTarget, bulkScan, JobStatus.TO_BE_EXECUTED); + + // Mock the static method call + try (MockedStatic mockedStatic = + mockStatic(BulkScanWorkerManager.class)) { + mockedStatic + .when( + () -> + BulkScanWorkerManager.handleStatic( + any(ScanJobDescription.class), anyInt(), anyInt())) + .thenReturn(resultFuture); + + // First call throws TimeoutException + when(resultFuture.get(anyLong(), any(TimeUnit.class))) + .thenThrow(new TimeoutException("Scan timeout")) + .thenReturn(null); + + // Start worker and capture the consumer + worker.start(); + verify(orchestrationProvider) + .registerScanJobConsumer(consumerCaptor.capture(), anyInt()); + ScanJobConsumer consumer = consumerCaptor.getValue(); + + // When + consumer.consumeScanJob(jobDescription); + + // Wait for async processing + Thread.sleep(100); + + // Then + verify(resultFuture).cancel(true); + verify(persistenceProvider) + .insertScanResult(scanResultCaptor.capture(), jobDescriptionCaptor.capture()); + + ScanResult capturedResult = scanResultCaptor.getValue(); + assertEquals(JobStatus.CANCELLED, capturedResult.getResultStatus()); + + verify(orchestrationProvider).notifyOfDoneScanJob(jobDescription); + } + } + + @Test + void testHandleScanJobExecutionException() throws Exception { + // Given + ScanTarget scanTarget = new ScanTarget(); + scanTarget.setHostname("example.com"); + scanTarget.setPort(443); + scanTarget.setIp("192.0.2.1"); + BulkScan bulkScan = createMockBulkScan(); + ScanJobDescription jobDescription = + new ScanJobDescription(scanTarget, bulkScan, JobStatus.TO_BE_EXECUTED); + + Exception cause = new RuntimeException("Scan failed"); + + // Mock the static method call + try (MockedStatic mockedStatic = + mockStatic(BulkScanWorkerManager.class)) { + mockedStatic + .when( + () -> + BulkScanWorkerManager.handleStatic( + any(ScanJobDescription.class), anyInt(), anyInt())) + .thenReturn(resultFuture); + + when(resultFuture.get(anyLong(), any(TimeUnit.class))) + .thenThrow(new ExecutionException(cause)); + + // Start worker and capture the consumer + worker.start(); + verify(orchestrationProvider) + .registerScanJobConsumer(consumerCaptor.capture(), anyInt()); + ScanJobConsumer consumer = consumerCaptor.getValue(); + + // When + consumer.consumeScanJob(jobDescription); + + // Wait for async processing + Thread.sleep(100); + + // Then + verify(persistenceProvider) + .insertScanResult(scanResultCaptor.capture(), jobDescriptionCaptor.capture()); + + ScanResult capturedResult = scanResultCaptor.getValue(); + assertEquals(JobStatus.ERROR, capturedResult.getResultStatus()); + assertNotNull(capturedResult.getResult()); + + verify(orchestrationProvider).notifyOfDoneScanJob(jobDescription); + } + } + + @Test + void testHandleScanJobInterruptedException() throws Exception { + // Given + ScanTarget scanTarget = new ScanTarget(); + scanTarget.setHostname("example.com"); + scanTarget.setPort(443); + scanTarget.setIp("192.0.2.1"); + BulkScan bulkScan = createMockBulkScan(); + ScanJobDescription jobDescription = + new ScanJobDescription(scanTarget, bulkScan, JobStatus.TO_BE_EXECUTED); + + // Mock the static method call + try (MockedStatic mockedStatic = + mockStatic(BulkScanWorkerManager.class)) { + mockedStatic + .when( + () -> + BulkScanWorkerManager.handleStatic( + any(ScanJobDescription.class), anyInt(), anyInt())) + .thenReturn(resultFuture); + + when(resultFuture.get(anyLong(), any(TimeUnit.class))) + .thenThrow(new InterruptedException("Worker interrupted")); + + // Start worker and capture the consumer + worker.start(); + verify(orchestrationProvider) + .registerScanJobConsumer(consumerCaptor.capture(), anyInt()); + ScanJobConsumer consumer = consumerCaptor.getValue(); + + // When + consumer.consumeScanJob(jobDescription); + + // Wait for async processing + Thread.sleep(100); + + // Then - should not persist on interrupt + verify(persistenceProvider, never()).insertScanResult(any(), any()); + verify(orchestrationProvider, never()).notifyOfDoneScanJob(any()); + assertEquals(JobStatus.INTERNAL_ERROR, jobDescription.getStatus()); + } + } + + @Test + void testHandleScanJobUnexpectedException() throws Exception { + // Given + ScanTarget scanTarget = new ScanTarget(); + scanTarget.setHostname("example.com"); + scanTarget.setPort(443); + scanTarget.setIp("192.0.2.1"); + BulkScan bulkScan = createMockBulkScan(); + ScanJobDescription jobDescription = + new ScanJobDescription(scanTarget, bulkScan, JobStatus.TO_BE_EXECUTED); + + // Mock the static method call + try (MockedStatic mockedStatic = + mockStatic(BulkScanWorkerManager.class)) { + mockedStatic + .when( + () -> + BulkScanWorkerManager.handleStatic( + any(ScanJobDescription.class), anyInt(), anyInt())) + .thenReturn(resultFuture); + + when(resultFuture.get(anyLong(), any(TimeUnit.class))) + .thenThrow(new RuntimeException("Unexpected error")); + + // Start worker and capture the consumer + worker.start(); + verify(orchestrationProvider) + .registerScanJobConsumer(consumerCaptor.capture(), anyInt()); + ScanJobConsumer consumer = consumerCaptor.getValue(); + + // When + consumer.consumeScanJob(jobDescription); + + // Wait for async processing + Thread.sleep(100); + + // Then + verify(persistenceProvider) + .insertScanResult(scanResultCaptor.capture(), jobDescriptionCaptor.capture()); + + ScanResult capturedResult = scanResultCaptor.getValue(); + assertEquals(JobStatus.CRAWLER_ERROR, capturedResult.getResultStatus()); + + verify(orchestrationProvider).notifyOfDoneScanJob(jobDescription); + } + } + + @Test + void testHandleScanJobNullResult() throws Exception { + // Given + ScanTarget scanTarget = new ScanTarget(); + scanTarget.setHostname("example.com"); + scanTarget.setPort(443); + scanTarget.setIp("192.0.2.1"); + BulkScan bulkScan = createMockBulkScan(); + ScanJobDescription jobDescription = + new ScanJobDescription(scanTarget, bulkScan, JobStatus.TO_BE_EXECUTED); + + // Mock the static method call + try (MockedStatic mockedStatic = + mockStatic(BulkScanWorkerManager.class)) { + mockedStatic + .when( + () -> + BulkScanWorkerManager.handleStatic( + any(ScanJobDescription.class), anyInt(), anyInt())) + .thenReturn(resultFuture); + + when(resultFuture.get(anyLong(), any(TimeUnit.class))).thenReturn(null); + + // Start worker and capture the consumer + worker.start(); + verify(orchestrationProvider) + .registerScanJobConsumer(consumerCaptor.capture(), anyInt()); + ScanJobConsumer consumer = consumerCaptor.getValue(); + + // When + consumer.consumeScanJob(jobDescription); + + // Wait for async processing + Thread.sleep(100); + + // Then + verify(persistenceProvider) + .insertScanResult(scanResultCaptor.capture(), jobDescriptionCaptor.capture()); + + ScanResult capturedResult = scanResultCaptor.getValue(); + assertEquals(JobStatus.EMPTY, capturedResult.getResultStatus()); + + verify(orchestrationProvider).notifyOfDoneScanJob(jobDescription); + } + } + + @Test + void testPersistResultException() throws Exception { + // Given + ScanTarget scanTarget = new ScanTarget(); + scanTarget.setHostname("example.com"); + scanTarget.setPort(443); + scanTarget.setIp("192.0.2.1"); + BulkScan bulkScan = createMockBulkScan(); + ScanJobDescription jobDescription = + new ScanJobDescription(scanTarget, bulkScan, JobStatus.TO_BE_EXECUTED); + + Document resultDocument = new Document("result", "success"); + + // Mock persistence to throw exception + doThrow(new RuntimeException("DB error")) + .when(persistenceProvider) + .insertScanResult(any(), any()); + + // Mock the static method call + try (MockedStatic mockedStatic = + mockStatic(BulkScanWorkerManager.class)) { + mockedStatic + .when( + () -> + BulkScanWorkerManager.handleStatic( + any(ScanJobDescription.class), anyInt(), anyInt())) + .thenReturn(resultFuture); + + when(resultFuture.get(anyLong(), any(TimeUnit.class))).thenReturn(resultDocument); + + // Start worker and capture the consumer + worker.start(); + verify(orchestrationProvider) + .registerScanJobConsumer(consumerCaptor.capture(), anyInt()); + ScanJobConsumer consumer = consumerCaptor.getValue(); + + // When + consumer.consumeScanJob(jobDescription); + + // Wait for async processing + Thread.sleep(100); + + // Then - should still notify even if persist fails + verify(orchestrationProvider).notifyOfDoneScanJob(jobDescription); + assertEquals(JobStatus.INTERNAL_ERROR, jobDescription.getStatus()); + } + } + + @Test + void testTimeoutWithGracefulShutdownFailure() throws Exception { + // Given + ScanTarget scanTarget = new ScanTarget(); + scanTarget.setHostname("example.com"); + scanTarget.setPort(443); + scanTarget.setIp("192.0.2.1"); + BulkScan bulkScan = createMockBulkScan(); + ScanJobDescription jobDescription = + new ScanJobDescription(scanTarget, bulkScan, JobStatus.TO_BE_EXECUTED); + + // Mock the static method call + try (MockedStatic mockedStatic = + mockStatic(BulkScanWorkerManager.class)) { + mockedStatic + .when( + () -> + BulkScanWorkerManager.handleStatic( + any(ScanJobDescription.class), anyInt(), anyInt())) + .thenReturn(resultFuture); + + // First call throws TimeoutException, second call also throws + when(resultFuture.get(anyLong(), any(TimeUnit.class))) + .thenThrow(new TimeoutException("Scan timeout")) + .thenThrow(new TimeoutException("Graceful shutdown failed")); + + // Start worker and capture the consumer + worker.start(); + verify(orchestrationProvider) + .registerScanJobConsumer(consumerCaptor.capture(), anyInt()); + ScanJobConsumer consumer = consumerCaptor.getValue(); + + // When + consumer.consumeScanJob(jobDescription); + + // Wait for async processing + Thread.sleep(100); + + // Then + verify(resultFuture, times(2)).cancel(true); + verify(persistenceProvider) + .insertScanResult(scanResultCaptor.capture(), jobDescriptionCaptor.capture()); + + ScanResult capturedResult = scanResultCaptor.getValue(); + assertEquals(JobStatus.CANCELLED, capturedResult.getResultStatus()); + + verify(orchestrationProvider).notifyOfDoneScanJob(jobDescription); + } + } + + private BulkScan createMockBulkScan() { + BulkScan bulkScan = mock(BulkScan.class); + when(bulkScan.getName()).thenReturn("test-scan"); + when(bulkScan.getCollectionName()).thenReturn("test-collection"); + return bulkScan; + } +} diff --git a/src/test/java/de/rub/nds/crawler/data/ScanJobDescriptionTest.java b/src/test/java/de/rub/nds/crawler/data/ScanJobDescriptionTest.java new file mode 100644 index 0000000..54560eb --- /dev/null +++ b/src/test/java/de/rub/nds/crawler/data/ScanJobDescriptionTest.java @@ -0,0 +1,177 @@ +/* + * TLS-Crawler - A TLS scanning tool to perform large scale scans with the TLS-Scanner + * + * Copyright 2018-2023 Ruhr University Bochum, Paderborn University, and Hackmanit GmbH + * + * Licensed under Apache License, Version 2.0 + * http://www.apache.org/licenses/LICENSE-2.0.txt + */ +package de.rub.nds.crawler.data; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + +import de.rub.nds.crawler.constant.JobStatus; +import java.io.*; +import org.junit.jupiter.api.*; +import org.mockito.*; + +class ScanJobDescriptionTest { + + @Mock private BulkScan mockBulkScan; + @Mock private BulkScanInfo mockBulkScanInfo; + + private ScanTarget scanTarget; + + @BeforeEach + void setUp() { + MockitoAnnotations.openMocks(this); + + scanTarget = new ScanTarget(); + scanTarget.setHostname("example.com"); + scanTarget.setPort(443); + + when(mockBulkScan.getName()).thenReturn("test-scan"); + when(mockBulkScan.getCollectionName()).thenReturn("test-collection"); + } + + @Test + void testConstructorWithBulkScanInfo() { + // When + ScanJobDescription job = + new ScanJobDescription( + scanTarget, + mockBulkScanInfo, + "dbName", + "collectionName", + JobStatus.TO_BE_EXECUTED); + + // Then + assertEquals(scanTarget, job.getScanTarget()); + assertEquals(mockBulkScanInfo, job.getBulkScanInfo()); + assertEquals("dbName", job.getDbName()); + assertEquals("collectionName", job.getCollectionName()); + assertEquals(JobStatus.TO_BE_EXECUTED, job.getStatus()); + } + + @Test + void testConstructorWithBulkScan() { + // When + ScanJobDescription job = + new ScanJobDescription(scanTarget, mockBulkScan, JobStatus.TO_BE_EXECUTED); + + // Then + assertEquals(scanTarget, job.getScanTarget()); + assertNotNull(job.getBulkScanInfo()); + assertEquals("test-scan", job.getDbName()); + assertEquals("test-collection", job.getCollectionName()); + assertEquals(JobStatus.TO_BE_EXECUTED, job.getStatus()); + } + + @Test + void testSetStatus() { + // Given + ScanJobDescription job = + new ScanJobDescription(scanTarget, mockBulkScan, JobStatus.TO_BE_EXECUTED); + + // When + job.setStatus(JobStatus.SUCCESS); + + // Then + assertEquals(JobStatus.SUCCESS, job.getStatus()); + } + + @Test + void testSetDeliveryTag() { + // Given + ScanJobDescription job = + new ScanJobDescription(scanTarget, mockBulkScan, JobStatus.TO_BE_EXECUTED); + + // When + job.setDeliveryTag(123L); + + // Then + assertEquals(123L, job.getDeliveryTag()); + } + + @Test + void testSetDeliveryTagTwiceThrowsException() { + // Given + ScanJobDescription job = + new ScanJobDescription(scanTarget, mockBulkScan, JobStatus.TO_BE_EXECUTED); + job.setDeliveryTag(123L); + + // When/Then + assertThrows(IllegalStateException.class, () -> job.setDeliveryTag(456L)); + } + + @Test + void testGetDeliveryTagBeforeSettingThrowsException() { + // Given + ScanJobDescription job = + new ScanJobDescription(scanTarget, mockBulkScan, JobStatus.TO_BE_EXECUTED); + + // When/Then + assertThrows(Exception.class, job::getDeliveryTag); + } + + @Test + void testSerialization() throws Exception { + // Given + ScanJobDescription job = + new ScanJobDescription(scanTarget, mockBulkScan, JobStatus.TO_BE_EXECUTED); + job.setDeliveryTag(123L); + + // When - serialize + ByteArrayOutputStream baos = new ByteArrayOutputStream(); + ObjectOutputStream oos = new ObjectOutputStream(baos); + oos.writeObject(job); + oos.close(); + + // And deserialize + ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray()); + ObjectInputStream ois = new ObjectInputStream(bais); + ScanJobDescription deserialized = (ScanJobDescription) ois.readObject(); + ois.close(); + + // Then + assertEquals(job.getScanTarget().getHostname(), deserialized.getScanTarget().getHostname()); + assertEquals(job.getDbName(), deserialized.getDbName()); + assertEquals(job.getCollectionName(), deserialized.getCollectionName()); + assertEquals(job.getStatus(), deserialized.getStatus()); + + // Delivery tag should not be serialized (transient) + assertThrows(Exception.class, deserialized::getDeliveryTag); + } + + @Test + void testAllJobStatuses() { + // Test that all job statuses can be set + JobStatus[] statuses = JobStatus.values(); + + for (JobStatus status : statuses) { + ScanJobDescription job = new ScanJobDescription(scanTarget, mockBulkScan, status); + assertEquals(status, job.getStatus()); + } + } + + @Test + void testGettersReturnCorrectValues() { + // Given + String dbName = "test-db"; + String collectionName = "test-collection"; + JobStatus status = JobStatus.TO_BE_EXECUTED; + + // When + ScanJobDescription job = + new ScanJobDescription( + scanTarget, mockBulkScanInfo, dbName, collectionName, status); + + // Then + assertEquals(scanTarget, job.getScanTarget()); + assertEquals(mockBulkScanInfo, job.getBulkScanInfo()); + assertEquals(dbName, job.getDbName()); + assertEquals(collectionName, job.getCollectionName()); + assertEquals(status, job.getStatus()); + } +} diff --git a/src/test/java/de/rub/nds/crawler/data/ScanResultTest.java b/src/test/java/de/rub/nds/crawler/data/ScanResultTest.java new file mode 100644 index 0000000..dda18c0 --- /dev/null +++ b/src/test/java/de/rub/nds/crawler/data/ScanResultTest.java @@ -0,0 +1,178 @@ +/* + * TLS-Crawler - A TLS scanning tool to perform large scale scans with the TLS-Scanner + * + * Copyright 2018-2023 Ruhr University Bochum, Paderborn University, and Hackmanit GmbH + * + * Licensed under Apache License, Version 2.0 + * http://www.apache.org/licenses/LICENSE-2.0.txt + */ +package de.rub.nds.crawler.data; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.Mockito.*; + +import de.rub.nds.crawler.constant.JobStatus; +import org.bson.Document; +import org.junit.jupiter.api.*; +import org.mockito.*; + +class ScanResultTest { + + @Mock private ScanJobDescription mockScanJobDescription; + @Mock private BulkScanInfo mockBulkScanInfo; + + private ScanTarget scanTarget; + private Document testDocument; + + @BeforeEach + void setUp() { + MockitoAnnotations.openMocks(this); + scanTarget = new ScanTarget(); + scanTarget.setHostname("example.com"); + scanTarget.setPort(443); + scanTarget.setIp("192.0.2.1"); + testDocument = new Document("test", "data"); + + when(mockScanJobDescription.getBulkScanInfo()).thenReturn(mockBulkScanInfo); + when(mockScanJobDescription.getScanTarget()).thenReturn(scanTarget); + when(mockBulkScanInfo.getBulkScanId()).thenReturn("bulk-scan-123"); + } + + @Test + void testConstructorWithScanJobDescription() { + // Given + when(mockScanJobDescription.getStatus()).thenReturn(JobStatus.SUCCESS); + + // When + ScanResult result = new ScanResult(mockScanJobDescription, testDocument); + + // Then + assertNotNull(result.getId()); + assertEquals("bulk-scan-123", result.getBulkScan()); + assertEquals(scanTarget, result.getScanTarget()); + assertEquals(JobStatus.SUCCESS, result.getResultStatus()); + assertEquals(testDocument, result.getResult()); + } + + @Test + void testConstructorWithScanJobDescriptionThrowsOnToBeExecuted() { + // Given + when(mockScanJobDescription.getStatus()).thenReturn(JobStatus.TO_BE_EXECUTED); + + // When/Then + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> new ScanResult(mockScanJobDescription, testDocument)); + assertEquals( + "ScanJobDescription must not be in TO_BE_EXECUTED state", exception.getMessage()); + } + + @Test + void testFromExceptionWithErrorStatus() { + // Given + when(mockScanJobDescription.getStatus()).thenReturn(JobStatus.ERROR); + Exception testException = new RuntimeException("Test error"); + + // When + ScanResult result = ScanResult.fromException(mockScanJobDescription, testException); + + // Then + assertNotNull(result.getId()); + assertEquals("bulk-scan-123", result.getBulkScan()); + assertEquals(scanTarget, result.getScanTarget()); + assertEquals(JobStatus.ERROR, result.getResultStatus()); + assertNotNull(result.getResult()); + assertEquals(testException, result.getResult().get("exception")); + } + + @Test + void testFromExceptionWithCancelledStatus() { + // Given + when(mockScanJobDescription.getStatus()).thenReturn(JobStatus.CANCELLED); + Exception testException = new RuntimeException("Cancelled"); + + // When + ScanResult result = ScanResult.fromException(mockScanJobDescription, testException); + + // Then + assertEquals(JobStatus.CANCELLED, result.getResultStatus()); + assertEquals(testException, result.getResult().get("exception")); + } + + @Test + void testFromExceptionThrowsOnNonErrorStatus() { + // Given + when(mockScanJobDescription.getStatus()).thenReturn(JobStatus.SUCCESS); + Exception testException = new RuntimeException("Test error"); + + // When/Then + IllegalArgumentException exception = + assertThrows( + IllegalArgumentException.class, + () -> ScanResult.fromException(mockScanJobDescription, testException)); + assertEquals("ScanJobDescription must be in an error state", exception.getMessage()); + } + + @Test + void testSetId() { + // Given + when(mockScanJobDescription.getStatus()).thenReturn(JobStatus.SUCCESS); + ScanResult result = new ScanResult(mockScanJobDescription, testDocument); + String newId = "new-id-456"; + + // When + result.setId(newId); + + // Then + assertEquals(newId, result.getId()); + } + + @Test + void testAllGetters() { + // Given + when(mockScanJobDescription.getStatus()).thenReturn(JobStatus.SUCCESS); + + // When + ScanResult result = new ScanResult(mockScanJobDescription, testDocument); + + // Then + assertNotNull(result.getId()); + assertTrue(result.getId().matches("[a-f0-9\\-]{36}")); // UUID format + assertEquals("bulk-scan-123", result.getBulkScan()); + assertEquals(scanTarget, result.getScanTarget()); + assertEquals(testDocument, result.getResult()); + assertEquals(JobStatus.SUCCESS, result.getResultStatus()); + } + + @Test + void testWithNullDocument() { + // Given + when(mockScanJobDescription.getStatus()).thenReturn(JobStatus.EMPTY); + + // When + ScanResult result = new ScanResult(mockScanJobDescription, null); + + // Then + assertNull(result.getResult()); + assertEquals(JobStatus.EMPTY, result.getResultStatus()); + } + + @Test + void testAllErrorStatuses() { + // Test that all error statuses work with fromException + JobStatus[] errorStatuses = { + JobStatus.ERROR, JobStatus.CANCELLED, JobStatus.INTERNAL_ERROR, JobStatus.CRAWLER_ERROR + }; + + for (JobStatus status : errorStatuses) { + when(mockScanJobDescription.getStatus()).thenReturn(status); + Exception testException = new RuntimeException("Test error for " + status); + + ScanResult result = ScanResult.fromException(mockScanJobDescription, testException); + + assertEquals(status, result.getResultStatus()); + assertEquals(testException, result.getResult().get("exception")); + } + } +} diff --git a/src/test/java/de/rub/nds/crawler/data/ScanTargetTest.java b/src/test/java/de/rub/nds/crawler/data/ScanTargetTest.java new file mode 100644 index 0000000..4e5931b --- /dev/null +++ b/src/test/java/de/rub/nds/crawler/data/ScanTargetTest.java @@ -0,0 +1,282 @@ +/* + * TLS-Crawler - A TLS scanning tool to perform large scale scans with the TLS-Scanner + * + * Copyright 2018-2023 Ruhr University Bochum, Paderborn University, and Hackmanit GmbH + * + * Licensed under Apache License, Version 2.0 + * http://www.apache.org/licenses/LICENSE-2.0.txt + */ +package de.rub.nds.crawler.data; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; + +import de.rub.nds.crawler.constant.JobStatus; +import de.rub.nds.crawler.denylist.IDenylistProvider; +import java.net.InetAddress; +import org.apache.commons.lang3.tuple.Pair; +import org.junit.jupiter.api.*; +import org.mockito.*; + +class ScanTargetTest { + + @Mock private IDenylistProvider mockDenylistProvider; + @Mock private InetAddress mockInetAddress; + + @BeforeEach + void setUp() { + MockitoAnnotations.openMocks(this); + } + + @Test + void testDefaultConstructor() { + ScanTarget target = new ScanTarget(); + assertNull(target.getIp()); + assertNull(target.getHostname()); + assertEquals(0, target.getPort()); + assertEquals(0, target.getTrancoRank()); + } + + @Test + void testSettersAndGetters() { + ScanTarget target = new ScanTarget(); + + target.setIp("192.0.2.1"); + target.setHostname("example.com"); + target.setPort(8443); + target.setTrancoRank(100); + + assertEquals("192.0.2.1", target.getIp()); + assertEquals("example.com", target.getHostname()); + assertEquals(8443, target.getPort()); + assertEquals(100, target.getTrancoRank()); + } + + @Test + void testToStringWithHostname() { + ScanTarget target = new ScanTarget(); + target.setHostname("example.com"); + target.setIp("192.0.2.1"); + + assertEquals("example.com", target.toString()); + } + + @Test + void testToStringWithoutHostname() { + ScanTarget target = new ScanTarget(); + target.setIp("192.0.2.1"); + + assertEquals("192.0.2.1", target.toString()); + } + + @Test + void testFromTargetStringWithSimpleHostname() { + // When + Pair result = ScanTarget.fromTargetString("example.com", 443, null); + + // Then + ScanTarget target = result.getLeft(); + assertEquals("example.com", target.getHostname()); + assertEquals(443, target.getPort()); + assertEquals(JobStatus.TO_BE_EXECUTED, result.getRight()); + } + + @Test + void testFromTargetStringWithIP() { + // When + Pair result = ScanTarget.fromTargetString("192.0.2.1", 443, null); + + // Then + ScanTarget target = result.getLeft(); + assertEquals("192.0.2.1", target.getIp()); + assertNull(target.getHostname()); + assertEquals(443, target.getPort()); + assertEquals(JobStatus.TO_BE_EXECUTED, result.getRight()); + } + + @Test + void testFromTargetStringWithPort() { + // When + Pair result = + ScanTarget.fromTargetString("example.com:8443", 443, null); + + // Then + ScanTarget target = result.getLeft(); + assertEquals("example.com", target.getHostname()); + assertEquals(8443, target.getPort()); + assertEquals(JobStatus.TO_BE_EXECUTED, result.getRight()); + } + + @Test + void testFromTargetStringWithTrancoRank() { + // When + Pair result = + ScanTarget.fromTargetString("100,example.com", 443, null); + + // Then + ScanTarget target = result.getLeft(); + assertEquals("example.com", target.getHostname()); + assertEquals(443, target.getPort()); + assertEquals(100, target.getTrancoRank()); + assertEquals(JobStatus.TO_BE_EXECUTED, result.getRight()); + } + + @Test + void testFromTargetStringWithTrancoRankAndPort() { + // When + Pair result = + ScanTarget.fromTargetString("100,example.com:8443", 443, null); + + // Then + ScanTarget target = result.getLeft(); + assertEquals("example.com", target.getHostname()); + assertEquals(8443, target.getPort()); + assertEquals(100, target.getTrancoRank()); + assertEquals(JobStatus.TO_BE_EXECUTED, result.getRight()); + } + + @Test + void testFromTargetStringWithInvalidTrancoRank() { + // When + Pair result = + ScanTarget.fromTargetString("abc,example.com", 443, null); + + // Then + ScanTarget target = result.getLeft(); + assertEquals("", target.getHostname()); + assertEquals(443, target.getPort()); + assertEquals(0, target.getTrancoRank()); + } + + @Test + void testFromTargetStringWithMxFormat() { + // When + Pair result = + ScanTarget.fromTargetString("mx://example.com", 443, null); + + // Then + ScanTarget target = result.getLeft(); + assertEquals("example.com", target.getHostname()); + assertEquals(443, target.getPort()); + assertEquals(JobStatus.TO_BE_EXECUTED, result.getRight()); + } + + @Test + void testFromTargetStringWithQuotes() { + // When + Pair result = + ScanTarget.fromTargetString("\"example.com\"", 443, null); + + // Then + ScanTarget target = result.getLeft(); + assertEquals("example.com", target.getHostname()); + assertEquals(443, target.getPort()); + assertEquals(JobStatus.TO_BE_EXECUTED, result.getRight()); + } + + @Test + void testFromTargetStringWithInvalidPort() { + // When - port 1 is considered invalid + Pair result1 = + ScanTarget.fromTargetString("example.com:1", 443, null); + + // Then - port is not set (remains 0) + assertEquals(0, result1.getLeft().getPort()); + + // When - port > 65535 is invalid (parseInt will fail) + try { + Pair result2 = + ScanTarget.fromTargetString("example.com:70000", 443, null); + // If it doesn't throw, port should be 0 + assertEquals(0, result2.getLeft().getPort()); + } catch (NumberFormatException e) { + // This is expected for port > 65535 + } + } + + @Test + void testFromTargetStringWithValidPortBoundaries() { + // Test port 2 (lowest valid) + Pair result1 = + ScanTarget.fromTargetString("example.com:2", 443, null); + assertEquals(2, result1.getLeft().getPort()); + + // Test port 65534 (highest valid < 65535) + Pair result2 = + ScanTarget.fromTargetString("example.com:65534", 443, null); + assertEquals(65534, result2.getLeft().getPort()); + } + + @Test + void testFromTargetStringUnknownHost() { + // This test would normally require mocking static InetAddress.getByName() + // Since we're testing the real behavior, we'll use a hostname that likely doesn't exist + + // When + Pair result = + ScanTarget.fromTargetString( + "this-host-definitely-does-not-exist-xyz123.com", 443, null); + + // Then + assertEquals(JobStatus.UNRESOLVABLE, result.getRight()); + assertNull(result.getLeft().getIp()); + } + + @Test + void testFromTargetStringDenylisted() { + // Given + when(mockDenylistProvider.isDenylisted(any(ScanTarget.class))).thenReturn(true); + + // When + Pair result = + ScanTarget.fromTargetString("192.0.2.1", 443, mockDenylistProvider); + + // Then + assertEquals(JobStatus.DENYLISTED, result.getRight()); + verify(mockDenylistProvider).isDenylisted(any(ScanTarget.class)); + } + + @Test + void testFromTargetStringNotDenylisted() { + // Given + when(mockDenylistProvider.isDenylisted(any(ScanTarget.class))).thenReturn(false); + + // When + Pair result = + ScanTarget.fromTargetString("192.0.2.1", 443, mockDenylistProvider); + + // Then + assertEquals(JobStatus.TO_BE_EXECUTED, result.getRight()); + verify(mockDenylistProvider).isDenylisted(any(ScanTarget.class)); + } + + @Test + void testFromTargetStringComplexScenarios() { + // Test with everything combined + Pair result = + ScanTarget.fromTargetString("50,mx://\"example.com\":8080", 443, null); + + ScanTarget target = result.getLeft(); + assertEquals( + "\"example.com\"", + target.getHostname()); // Quotes are kept when mixed with // processing + assertEquals(8080, target.getPort()); + assertEquals(50, target.getTrancoRank()); + assertEquals( + JobStatus.UNRESOLVABLE, + result.getRight()); // Hostname with quotes can't be resolved + } + + @Test + void testFromTargetStringWithIPv4() { + // Test various IPv4 formats + String[] validIPs = {"192.0.2.1", "10.0.0.1", "172.16.0.1", "255.255.255.255"}; + + for (String ip : validIPs) { + Pair result = ScanTarget.fromTargetString(ip, 443, null); + assertEquals(ip, result.getLeft().getIp()); + assertNull(result.getLeft().getHostname()); + } + } +} diff --git a/src/test/java/de/rub/nds/crawler/denylist/DenylistFileProviderTest.java b/src/test/java/de/rub/nds/crawler/denylist/DenylistFileProviderTest.java new file mode 100644 index 0000000..203aa42 --- /dev/null +++ b/src/test/java/de/rub/nds/crawler/denylist/DenylistFileProviderTest.java @@ -0,0 +1,235 @@ +/* + * TLS-Crawler - A TLS scanning tool to perform large scale scans with the TLS-Scanner + * + * Copyright 2018-2023 Ruhr University Bochum, Paderborn University, and Hackmanit GmbH + * + * Licensed under Apache License, Version 2.0 + * http://www.apache.org/licenses/LICENSE-2.0.txt + */ +package de.rub.nds.crawler.denylist; + +import static org.junit.jupiter.api.Assertions.*; + +import de.rub.nds.crawler.data.ScanTarget; +import java.io.FileWriter; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import org.junit.jupiter.api.*; + +class DenylistFileProviderTest { + + private Path tempFile; + private DenylistFileProvider provider; + + @BeforeEach + void setUp() throws IOException { + tempFile = Files.createTempFile("denylist", ".txt"); + } + + @AfterEach + void tearDown() throws IOException { + if (tempFile != null && Files.exists(tempFile)) { + Files.delete(tempFile); + } + } + + @Test + void testEmptyDenylist() throws IOException { + // Given - empty file + provider = new DenylistFileProvider(tempFile.toString()); + + // When + ScanTarget target = createTarget("example.com", "192.0.2.1"); + + // Then + assertFalse(provider.isDenylisted(target)); + } + + @Test + void testDenylistWithDomain() throws IOException { + // Given + writeDenylist("badsite.com", "evil.org"); + provider = new DenylistFileProvider(tempFile.toString()); + + // Then - denylisted domains + assertTrue(provider.isDenylisted(createTarget("badsite.com", "192.0.2.1"))); + assertTrue(provider.isDenylisted(createTarget("evil.org", "192.0.2.2"))); + + // Not denylisted + assertFalse(provider.isDenylisted(createTarget("goodsite.com", "192.0.2.3"))); + } + + @Test + void testDenylistWithIP() throws IOException { + // Given + writeDenylist("192.0.2.1", "10.0.0.1"); + provider = new DenylistFileProvider(tempFile.toString()); + + // Then - denylisted IPs + assertTrue(provider.isDenylisted(createTarget("example.com", "192.0.2.1"))); + assertTrue(provider.isDenylisted(createTarget("test.com", "10.0.0.1"))); + + // Not denylisted + assertFalse(provider.isDenylisted(createTarget("example.com", "192.0.2.2"))); + } + + @Test + void testDenylistWithCIDR() throws IOException { + // Given + writeDenylist("192.0.2.0/24", "10.0.0.0/16"); + provider = new DenylistFileProvider(tempFile.toString()); + + // Then - IPs in subnet are denylisted + assertTrue(provider.isDenylisted(createTarget("example.com", "192.0.2.1"))); + assertTrue(provider.isDenylisted(createTarget("example.com", "192.0.2.255"))); + assertTrue(provider.isDenylisted(createTarget("example.com", "10.0.1.1"))); + assertTrue(provider.isDenylisted(createTarget("example.com", "10.0.255.255"))); + + // IPs outside subnet are not denylisted + assertFalse(provider.isDenylisted(createTarget("example.com", "192.0.3.1"))); + assertFalse(provider.isDenylisted(createTarget("example.com", "10.1.0.1"))); + } + + @Test + void testDenylistWithMixedEntries() throws IOException { + // Given + writeDenylist("badsite.com", "192.0.2.1", "10.0.0.0/24", "evil.org", "172.16.0.1"); + provider = new DenylistFileProvider(tempFile.toString()); + + // Then - all types work + assertTrue(provider.isDenylisted(createTarget("badsite.com", "1.1.1.1"))); + assertTrue(provider.isDenylisted(createTarget("example.com", "192.0.2.1"))); + assertTrue(provider.isDenylisted(createTarget("example.com", "10.0.0.100"))); + assertTrue(provider.isDenylisted(createTarget("evil.org", "2.2.2.2"))); + assertTrue(provider.isDenylisted(createTarget("example.com", "172.16.0.1"))); + } + + @Test + void testInvalidDenylistEntries() throws IOException { + // Given - invalid entries should be ignored + writeDenylist( + "not-a-valid-domain-or-ip", + "999.999.999.999", // invalid IP + "192.0.2.0/999", // invalid CIDR + "example.com" // valid + ); + provider = new DenylistFileProvider(tempFile.toString()); + + // Then - only valid entry works + assertTrue(provider.isDenylisted(createTarget("example.com", "1.1.1.1"))); + assertFalse(provider.isDenylisted(createTarget("not-a-valid-domain-or-ip", "1.1.1.1"))); + } + + @Test + void testNonExistentFile() { + // Given - file that doesn't exist + provider = new DenylistFileProvider("/path/that/does/not/exist/denylist.txt"); + + // Then - should not crash and nothing is denylisted + assertFalse(provider.isDenylisted(createTarget("example.com", "192.0.2.1"))); + } + + @Test + void testTargetWithNullHostname() throws IOException { + // Given + writeDenylist("192.0.2.1"); + provider = new DenylistFileProvider(tempFile.toString()); + + // When - target with null hostname + ScanTarget target = new ScanTarget(); + target.setIp("192.0.2.1"); + + // Then + assertTrue(provider.isDenylisted(target)); + } + + @Test + void testTargetWithNullIP() throws IOException { + // Given + writeDenylist("example.com"); + provider = new DenylistFileProvider(tempFile.toString()); + + // When - target with null IP + ScanTarget target = new ScanTarget(); + target.setHostname("example.com"); + + // Then + assertTrue(provider.isDenylisted(target)); + } + + @Test + void testCIDRBoundaries() throws IOException { + // Given + writeDenylist("192.0.2.0/30"); // .0, .1, .2, .3 + provider = new DenylistFileProvider(tempFile.toString()); + + // Then + assertTrue(provider.isDenylisted(createTarget("test.com", "192.0.2.0"))); + assertTrue(provider.isDenylisted(createTarget("test.com", "192.0.2.1"))); + assertTrue(provider.isDenylisted(createTarget("test.com", "192.0.2.2"))); + assertTrue(provider.isDenylisted(createTarget("test.com", "192.0.2.3"))); + assertFalse(provider.isDenylisted(createTarget("test.com", "192.0.2.4"))); + } + + @Test + void testConcurrentAccess() throws IOException, InterruptedException { + // Given + writeDenylist("192.0.2.0/24", "badsite.com"); + provider = new DenylistFileProvider(tempFile.toString()); + + // When - multiple threads access isDenylisted + int threadCount = 10; + Thread[] threads = new Thread[threadCount]; + boolean[] results = new boolean[threadCount]; + + for (int i = 0; i < threadCount; i++) { + final int index = i; + threads[i] = + new Thread( + () -> { + ScanTarget target = createTarget("badsite.com", "192.0.2." + index); + results[index] = provider.isDenylisted(target); + }); + threads[i].start(); + } + + // Wait for all threads + for (Thread thread : threads) { + thread.join(); + } + + // Then - all should return true (synchronized method) + for (boolean result : results) { + assertTrue(result); + } + } + + @Test + void testIPv6Handling() throws IOException { + // Given - IPv4 subnet + writeDenylist("192.0.2.0/24"); + provider = new DenylistFileProvider(tempFile.toString()); + + // When - checking IPv6 address against IPv4 subnet + ScanTarget target = createTarget("example.com", "2001:db8::1"); + + // Then - should not crash and return false + assertFalse(provider.isDenylisted(target)); + } + + private void writeDenylist(String... entries) throws IOException { + try (FileWriter writer = new FileWriter(tempFile.toFile())) { + for (String entry : entries) { + writer.write(entry + "\n"); + } + } + } + + private ScanTarget createTarget(String hostname, String ip) { + ScanTarget target = new ScanTarget(); + target.setHostname(hostname); + target.setIp(ip); + return target; + } +} diff --git a/src/test/java/de/rub/nds/crawler/orchestration/RabbitMqOrchestrationProviderTest.java b/src/test/java/de/rub/nds/crawler/orchestration/RabbitMqOrchestrationProviderTest.java new file mode 100644 index 0000000..8e2d5af --- /dev/null +++ b/src/test/java/de/rub/nds/crawler/orchestration/RabbitMqOrchestrationProviderTest.java @@ -0,0 +1,522 @@ +/* + * TLS-Crawler - A TLS scanning tool to perform large scale scans with the TLS-Scanner + * + * Copyright 2018-2022 Ruhr University Bochum, Paderborn University, and Hackmanit GmbH + * + * Licensed under Apache License, Version 2.0 + * http://www.apache.org/licenses/LICENSE-2.0.txt + */ +package de.rub.nds.crawler.orchestration; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; + +import com.rabbitmq.client.*; +import de.rub.nds.crawler.config.delegate.RabbitMqDelegate; +import de.rub.nds.crawler.constant.JobStatus; +import de.rub.nds.crawler.core.BulkScanWorker; +import de.rub.nds.crawler.data.BulkScan; +import de.rub.nds.crawler.data.ScanJobDescription; +import de.rub.nds.crawler.data.ScanTarget; +import de.rub.nds.scanner.core.config.ScannerDetail; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Path; +import java.security.NoSuchAlgorithmException; +import java.util.Map; +import java.util.concurrent.TimeoutException; +import org.apache.commons.lang3.SerializationUtils; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import org.mockito.*; + +class RabbitMqOrchestrationProviderTest { + + @Mock private ConnectionFactory connectionFactory; + + @Mock private Connection connection; + + @Mock private Channel channel; + + @Mock private RabbitMqDelegate rabbitMqDelegate; + + @Captor private ArgumentCaptor deliverCallbackCaptor; + + @Captor private ArgumentCaptor cancelCallbackCaptor; + + @TempDir Path tempDir; + + private RabbitMqOrchestrationProvider provider; + + @BeforeEach + void setUp() throws IOException, TimeoutException { + MockitoAnnotations.openMocks(this); + when(rabbitMqDelegate.getRabbitMqHost()).thenReturn("localhost"); + when(rabbitMqDelegate.getRabbitMqPort()).thenReturn(5672); + } + + private BulkScan createTestBulkScan(String id, String name) { + de.rub.nds.crawler.data.ScanConfig scanConfig = + new de.rub.nds.crawler.data.ScanConfig(ScannerDetail.NORMAL, 3, 2000) { + @Override + public BulkScanWorker + createWorker( + String bulkScanID, + int parallelConnectionThreads, + int parallelScanThreads) { + return null; + } + }; + BulkScan bulkScan = + new BulkScan( + RabbitMqOrchestrationProviderTest.class, + RabbitMqOrchestrationProviderTest.class, + name, + scanConfig, + System.currentTimeMillis(), + true, + null); + bulkScan.set_id(id); + return bulkScan; + } + + @Test + void testConstructorWithBasicAuth() throws Exception { + try (MockedConstruction mockedFactory = + mockConstruction( + ConnectionFactory.class, + (mock, context) -> { + when(mock.newConnection()).thenReturn(connection); + })) { + when(connection.createChannel()).thenReturn(channel); + when(rabbitMqDelegate.getRabbitMqUser()).thenReturn("user"); + when(rabbitMqDelegate.getRabbitMqPass()).thenReturn("pass"); + + provider = new RabbitMqOrchestrationProvider(rabbitMqDelegate); + + ConnectionFactory factory = mockedFactory.constructed().get(0); + verify(factory).setHost("localhost"); + verify(factory).setPort(5672); + verify(factory).setUsername("user"); + verify(factory).setPassword("pass"); + verify(channel) + .queueDeclare( + eq("scan-job-queue"), eq(false), eq(false), eq(false), any(Map.class)); + } + } + + @Test + void testConstructorWithPasswordFile() throws Exception { + Path passFile = tempDir.resolve("password.txt"); + Files.write(passFile, "filepass".getBytes()); + + try (MockedConstruction mockedFactory = + mockConstruction( + ConnectionFactory.class, + (mock, context) -> { + when(mock.newConnection()).thenReturn(connection); + })) { + when(connection.createChannel()).thenReturn(channel); + when(rabbitMqDelegate.getRabbitMqPassFile()).thenReturn(passFile.toString()); + + provider = new RabbitMqOrchestrationProvider(rabbitMqDelegate); + + ConnectionFactory factory = mockedFactory.constructed().get(0); + verify(factory).setPassword("filepass"); + } + } + + @Test + void testConstructorWithTLS() throws Exception { + try (MockedConstruction mockedFactory = + mockConstruction( + ConnectionFactory.class, + (mock, context) -> { + when(mock.newConnection()).thenReturn(connection); + })) { + when(connection.createChannel()).thenReturn(channel); + when(rabbitMqDelegate.isRabbitMqTLS()).thenReturn(true); + + provider = new RabbitMqOrchestrationProvider(rabbitMqDelegate); + + ConnectionFactory factory = mockedFactory.constructed().get(0); + verify(factory).useSslProtocol(); + } + } + + @Test + void testConstructorWithTLSException() throws Exception { + try (MockedConstruction mockedFactory = + mockConstruction( + ConnectionFactory.class, + (mock, context) -> { + when(mock.newConnection()).thenReturn(connection); + doThrow(new NoSuchAlgorithmException("Test")) + .when(mock) + .useSslProtocol(); + })) { + when(connection.createChannel()).thenReturn(channel); + when(rabbitMqDelegate.isRabbitMqTLS()).thenReturn(true); + + // Should not throw, just log error + assertDoesNotThrow( + () -> provider = new RabbitMqOrchestrationProvider(rabbitMqDelegate)); + } + } + + @Test + void testConstructorConnectionException() throws Exception { + try (MockedConstruction mockedFactory = + mockConstruction( + ConnectionFactory.class, + (mock, context) -> { + when(mock.newConnection()) + .thenThrow(new IOException("Connection failed")); + })) { + + assertThrows( + RuntimeException.class, + () -> provider = new RabbitMqOrchestrationProvider(rabbitMqDelegate)); + } + } + + @Test + void testSubmitScanJob() throws Exception { + setupProvider(); + + ScanTarget target = new ScanTarget(); + target.setHostname("example.com"); + BulkScan bulkScan = + new BulkScan( + RabbitMqOrchestrationProviderTest.class, + RabbitMqOrchestrationProviderTest.class, + "test-scan", + null, + System.currentTimeMillis(), + true, + null); + ScanJobDescription scanJob = + new ScanJobDescription(target, bulkScan, JobStatus.TO_BE_EXECUTED); + + provider.submitScanJob(scanJob); + + verify(channel) + .basicPublish( + eq(""), + eq("scan-job-queue"), + isNull(), + eq(SerializationUtils.serialize(scanJob))); + } + + @Test + void testSubmitScanJobIOException() throws Exception { + setupProvider(); + + ScanTarget target = new ScanTarget(); + target.setHostname("example.com"); + BulkScan bulkScan = createTestBulkScan("bulk-id", "test-scan"); + ScanJobDescription scanJob = + new ScanJobDescription(target, bulkScan, JobStatus.TO_BE_EXECUTED); + doThrow(new IOException("Publish failed")) + .when(channel) + .basicPublish(any(), any(), any(), any()); + + // Should not throw, just log error + assertDoesNotThrow(() -> provider.submitScanJob(scanJob)); + } + + @Test + void testRegisterScanJobConsumer() throws Exception { + setupProvider(); + + ScanJobConsumer consumer = mock(ScanJobConsumer.class); + int prefetchCount = 10; + + provider.registerScanJobConsumer(consumer, prefetchCount); + + verify(channel).basicQos(prefetchCount); + verify(channel) + .basicConsume( + eq("scan-job-queue"), + eq(false), + deliverCallbackCaptor.capture(), + any(CancelCallback.class)); + + // Test the delivery callback + DeliverCallback deliverCallback = deliverCallbackCaptor.getValue(); + ScanTarget target = new ScanTarget(); + target.setHostname("example.com"); + BulkScan bulkScan = + new BulkScan( + RabbitMqOrchestrationProviderTest.class, + RabbitMqOrchestrationProviderTest.class, + "test-scan", + null, + System.currentTimeMillis(), + true, + null); + ScanJobDescription scanJob = + new ScanJobDescription(target, bulkScan, JobStatus.TO_BE_EXECUTED); + + Envelope envelope = mock(Envelope.class); + when(envelope.getDeliveryTag()).thenReturn(123L); + + AMQP.BasicProperties properties = mock(AMQP.BasicProperties.class); + Delivery delivery = + new Delivery(envelope, properties, SerializationUtils.serialize(scanJob)); + + deliverCallback.handle("consumerTag", delivery); + + ArgumentCaptor scanJobCaptor = + ArgumentCaptor.forClass(ScanJobDescription.class); + verify(consumer).consumeScanJob(scanJobCaptor.capture()); + assertEquals("example.com", scanJobCaptor.getValue().getScanTarget().getHostname()); + assertEquals(123L, scanJobCaptor.getValue().getDeliveryTag()); + } + + @Test + void testRegisterScanJobConsumerDeserializationError() throws Exception { + setupProvider(); + + ScanJobConsumer consumer = mock(ScanJobConsumer.class); + + provider.registerScanJobConsumer(consumer, 1); + + verify(channel) + .basicConsume( + eq("scan-job-queue"), + eq(false), + deliverCallbackCaptor.capture(), + any(CancelCallback.class)); + + // Test the delivery callback with bad data + DeliverCallback deliverCallback = deliverCallbackCaptor.getValue(); + + Envelope envelope = mock(Envelope.class); + when(envelope.getDeliveryTag()).thenReturn(123L); + + AMQP.BasicProperties properties = mock(AMQP.BasicProperties.class); + Delivery delivery = new Delivery(envelope, properties, "bad data".getBytes()); + + deliverCallback.handle("consumerTag", delivery); + + verify(channel).basicReject(123L, false); + verify(consumer, never()).consumeScanJob(any()); + } + + @Test + void testRegisterScanJobConsumerIOException() throws Exception { + setupProvider(); + + ScanJobConsumer consumer = mock(ScanJobConsumer.class); + doThrow(new IOException("Register failed")) + .when(channel) + .basicConsume( + any(), anyBoolean(), any(DeliverCallback.class), any(CancelCallback.class)); + + // Should not throw, just log error + assertDoesNotThrow(() -> provider.registerScanJobConsumer(consumer, 1)); + } + + @Test + void testRegisterDoneNotificationConsumer() throws Exception { + setupProvider(); + + BulkScan bulkScan = createTestBulkScan("bulk-123", "test-scan"); + DoneNotificationConsumer consumer = mock(DoneNotificationConsumer.class); + + provider.registerDoneNotificationConsumer(bulkScan, consumer); + + verify(channel) + .queueDeclare( + eq("done-notify-queue_bulk-123"), + eq(false), + eq(false), + eq(true), + any(Map.class)); + verify(channel).basicQos(1); + verify(channel) + .basicConsume( + eq("done-notify-queue_bulk-123"), + eq(true), + deliverCallbackCaptor.capture(), + any(CancelCallback.class)); + + // Test the delivery callback + DeliverCallback deliverCallback = deliverCallbackCaptor.getValue(); + ScanTarget target = new ScanTarget(); + target.setHostname("example.com"); + BulkScan testBulkScan = createTestBulkScan("test-id", "test-scan"); + ScanJobDescription scanJob = + new ScanJobDescription(target, testBulkScan, JobStatus.TO_BE_EXECUTED); + + Envelope envelope = mock(Envelope.class); + AMQP.BasicProperties properties = mock(AMQP.BasicProperties.class); + Delivery delivery = + new Delivery(envelope, properties, SerializationUtils.serialize(scanJob)); + + deliverCallback.handle("consumerTag", delivery); + + ArgumentCaptor scanJobCaptor = + ArgumentCaptor.forClass(ScanJobDescription.class); + verify(consumer).consumeDoneNotification(eq("consumerTag"), scanJobCaptor.capture()); + assertEquals("example.com", scanJobCaptor.getValue().getScanTarget().getHostname()); + } + + @Test + void testRegisterDoneNotificationConsumerQueueAlreadyDeclared() throws Exception { + setupProvider(); + + // First registration + BulkScan bulkScan = createTestBulkScan("bulk-123", "test-scan"); + DoneNotificationConsumer consumer1 = mock(DoneNotificationConsumer.class); + provider.registerDoneNotificationConsumer(bulkScan, consumer1); + + // Second registration with same bulk scan ID + DoneNotificationConsumer consumer2 = mock(DoneNotificationConsumer.class); + provider.registerDoneNotificationConsumer(bulkScan, consumer2); + + // Queue should only be declared once + verify(channel, times(1)) + .queueDeclare( + eq("done-notify-queue_bulk-123"), + eq(false), + eq(false), + eq(true), + any(Map.class)); + } + + @Test + void testRegisterDoneNotificationConsumerIOException() throws Exception { + setupProvider(); + + BulkScan bulkScan = createTestBulkScan("bulk-123", "test-scan"); + DoneNotificationConsumer consumer = mock(DoneNotificationConsumer.class); + + doThrow(new IOException("Queue declare failed")) + .when(channel) + .queueDeclare(any(), anyBoolean(), anyBoolean(), anyBoolean(), any()); + + // Should not throw, just log error + assertDoesNotThrow(() -> provider.registerDoneNotificationConsumer(bulkScan, consumer)); + } + + @Test + void testNotifyOfDoneScanJobMonitored() throws Exception { + setupProvider(); + + ScanTarget target = new ScanTarget(); + target.setHostname("example.com"); + BulkScan bulkScan = createTestBulkScan("bulk-123", "test-scan"); + bulkScan.setMonitored(true); + ScanJobDescription scanJob = new ScanJobDescription(target, bulkScan, JobStatus.SUCCESS); + scanJob.setDeliveryTag(123L); + + provider.notifyOfDoneScanJob(scanJob); + + verify(channel).basicAck(123L, false); + verify(channel) + .queueDeclare( + eq("done-notify-queue_bulk-123"), + eq(false), + eq(false), + eq(true), + any(Map.class)); + verify(channel) + .basicPublish( + eq(""), + eq("done-notify-queue_bulk-123"), + isNull(), + eq(SerializationUtils.serialize(scanJob))); + } + + @Test + void testNotifyOfDoneScanJobNotMonitored() throws Exception { + setupProvider(); + + ScanTarget target = new ScanTarget(); + target.setHostname("example.com"); + BulkScan bulkScan = createTestBulkScan("bulk-123", "test-scan"); + bulkScan.setMonitored(false); + ScanJobDescription scanJob = new ScanJobDescription(target, bulkScan, JobStatus.SUCCESS); + scanJob.setDeliveryTag(123L); + + provider.notifyOfDoneScanJob(scanJob); + + verify(channel).basicAck(123L, false); + verify(channel, never()).basicPublish(any(), any(), any(), any()); + } + + @Test + void testNotifyOfDoneScanJobAckException() throws Exception { + setupProvider(); + + ScanTarget target = new ScanTarget(); + target.setHostname("example.com"); + BulkScan bulkScan = createTestBulkScan("bulk-123", "test-scan"); + bulkScan.setMonitored(false); + ScanJobDescription scanJob = new ScanJobDescription(target, bulkScan, JobStatus.SUCCESS); + scanJob.setDeliveryTag(123L); + + doThrow(new IOException("Ack failed")).when(channel).basicAck(anyLong(), anyBoolean()); + + // Should not throw, just log error + assertDoesNotThrow(() -> provider.notifyOfDoneScanJob(scanJob)); + } + + @Test + void testNotifyOfDoneScanJobPublishException() throws Exception { + setupProvider(); + + ScanTarget target = new ScanTarget(); + target.setHostname("example.com"); + BulkScan bulkScan = createTestBulkScan("bulk-123", "test-scan"); + bulkScan.setMonitored(true); + ScanJobDescription scanJob = new ScanJobDescription(target, bulkScan, JobStatus.SUCCESS); + scanJob.setDeliveryTag(123L); + + doThrow(new IOException("Publish failed")) + .when(channel) + .basicPublish(any(), any(), any(), any()); + + // Should not throw, just log error + assertDoesNotThrow(() -> provider.notifyOfDoneScanJob(scanJob)); + + // Ack should still be sent + verify(channel).basicAck(123L, false); + } + + @Test + void testCloseConnection() throws Exception { + setupProvider(); + + provider.closeConnection(); + + verify(channel).close(); + verify(connection).close(); + } + + @Test + void testCloseConnectionException() throws Exception { + setupProvider(); + + doThrow(new IOException("Close failed")).when(channel).close(); + + // Should not throw, just log error + assertDoesNotThrow(() -> provider.closeConnection()); + } + + private void setupProvider() throws Exception { + try (MockedConstruction mockedFactory = + mockConstruction( + ConnectionFactory.class, + (mock, context) -> { + when(mock.newConnection()).thenReturn(connection); + })) { + when(connection.createChannel()).thenReturn(channel); + provider = new RabbitMqOrchestrationProvider(rabbitMqDelegate); + } + } +} diff --git a/src/test/java/de/rub/nds/crawler/persistence/MongoPersistenceProviderTest.java b/src/test/java/de/rub/nds/crawler/persistence/MongoPersistenceProviderTest.java new file mode 100644 index 0000000..6aeb6db --- /dev/null +++ b/src/test/java/de/rub/nds/crawler/persistence/MongoPersistenceProviderTest.java @@ -0,0 +1,492 @@ +/* + * TLS-Crawler - A TLS scanning tool to perform large scale scans with the TLS-Scanner + * + * Copyright 2018-2022 Ruhr University Bochum, Paderborn University, and Hackmanit GmbH + * + * Licensed under Apache License, Version 2.0 + * http://www.apache.org/licenses/LICENSE-2.0.txt + */ +package de.rub.nds.crawler.persistence; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; + +import com.fasterxml.jackson.databind.JsonSerializer; +import com.fasterxml.jackson.databind.Module; +import com.mongodb.client.*; +import de.rub.nds.crawler.config.delegate.MongoDbDelegate; +import de.rub.nds.crawler.constant.JobStatus; +import de.rub.nds.crawler.core.BulkScanWorker; +import de.rub.nds.crawler.data.BulkScan; +import de.rub.nds.crawler.data.ScanJobDescription; +import de.rub.nds.crawler.data.ScanResult; +import de.rub.nds.crawler.data.ScanTarget; +import de.rub.nds.scanner.core.config.ScannerDetail; +import java.io.IOException; +import java.lang.reflect.Field; +import java.nio.file.Files; +import java.nio.file.Path; +import org.bson.conversions.Bson; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import org.mockito.*; +import org.mongojack.JacksonMongoCollection; + +class MongoPersistenceProviderTest { + + @Mock private MongoDbDelegate mongoDbDelegate; + + @Mock private MongoClient mongoClient; + + @Mock private ClientSession clientSession; + + @Mock private MongoDatabase mongoDatabase; + + @Mock private MongoCollection scanResultCollection; + + @Mock private MongoCollection bulkScanCollection; + + @Mock private JacksonMongoCollection jacksonScanResultCollection; + + @Mock private JacksonMongoCollection jacksonBulkScanCollection; + + @Mock private JsonSerializer mockSerializer; + + @Mock private Module mockModule; + + @TempDir Path tempDir; + + private MongoPersistenceProvider provider; + + @BeforeEach + void setUp() throws Exception { + MockitoAnnotations.openMocks(this); + // Clear static state + Field isInitializedField = MongoPersistenceProvider.class.getDeclaredField("isInitialized"); + isInitializedField.setAccessible(true); + isInitializedField.set(null, false); + + Field serializersField = MongoPersistenceProvider.class.getDeclaredField("serializers"); + serializersField.setAccessible(true); + ((java.util.Set) serializersField.get(null)).clear(); + + Field modulesField = MongoPersistenceProvider.class.getDeclaredField("modules"); + modulesField.setAccessible(true); + ((java.util.Set) modulesField.get(null)).clear(); + + when(mongoDbDelegate.getMongoDbHost()).thenReturn("localhost"); + when(mongoDbDelegate.getMongoDbPort()).thenReturn(27017); + when(mongoDbDelegate.getMongoDbUser()).thenReturn("user"); + when(mongoDbDelegate.getMongoDbPass()).thenReturn("pass"); + when(mongoDbDelegate.getMongoDbAuthSource()).thenReturn("admin"); + } + + @AfterEach + void tearDown() throws Exception { + // Reset static state + Field isInitializedField = MongoPersistenceProvider.class.getDeclaredField("isInitialized"); + isInitializedField.setAccessible(true); + isInitializedField.set(null, false); + } + + private BulkScan createTestBulkScan(String id, String name) { + de.rub.nds.crawler.data.ScanConfig scanConfig = + new de.rub.nds.crawler.data.ScanConfig(ScannerDetail.NORMAL, 3, 2000) { + @Override + public BulkScanWorker + createWorker( + String bulkScanID, + int parallelConnectionThreads, + int parallelScanThreads) { + return null; + } + }; + BulkScan bulkScan = + new BulkScan( + MongoPersistenceProviderTest.class, + MongoPersistenceProviderTest.class, + name, + scanConfig, + System.currentTimeMillis(), + true, + null); + bulkScan.set_id(id); + return bulkScan; + } + + @Test + void testRegisterSerializer() { + assertDoesNotThrow(() -> MongoPersistenceProvider.registerSerializer(mockSerializer)); + } + + @Test + void testRegisterSerializerAfterInitialization() { + try (MockedStatic mockedMongoClients = mockStatic(MongoClients.class)) { + mockedMongoClients + .when(() -> MongoClients.create((String) any())) + .thenReturn(mongoClient); + when(mongoClient.startSession()).thenReturn(clientSession); + + provider = new MongoPersistenceProvider(mongoDbDelegate); + + assertThrows( + RuntimeException.class, + () -> MongoPersistenceProvider.registerSerializer(mockSerializer)); + } + } + + @Test + void testRegisterMultipleSerializers() { + JsonSerializer serializer1 = mock(JsonSerializer.class); + JsonSerializer serializer2 = mock(JsonSerializer.class); + assertDoesNotThrow( + () -> MongoPersistenceProvider.registerSerializer(serializer1, serializer2)); + } + + @Test + void testRegisterModule() { + assertDoesNotThrow(() -> MongoPersistenceProvider.registerModule(mockModule)); + } + + @Test + void testRegisterModuleAfterInitialization() { + try (MockedStatic mockedMongoClients = mockStatic(MongoClients.class)) { + mockedMongoClients + .when(() -> MongoClients.create((String) any())) + .thenReturn(mongoClient); + when(mongoClient.startSession()).thenReturn(clientSession); + + provider = new MongoPersistenceProvider(mongoDbDelegate); + + assertThrows( + RuntimeException.class, + () -> MongoPersistenceProvider.registerModule(mockModule)); + } + } + + @Test + void testRegisterMultipleModules() { + Module module1 = mock(Module.class); + Module module2 = mock(Module.class); + assertDoesNotThrow(() -> MongoPersistenceProvider.registerModule(module1, module2)); + } + + @Test + void testConstructorWithPassword() { + try (MockedStatic mockedMongoClients = mockStatic(MongoClients.class)) { + mockedMongoClients + .when(() -> MongoClients.create((String) any())) + .thenReturn(mongoClient); + when(mongoClient.startSession()).thenReturn(clientSession); + + provider = new MongoPersistenceProvider(mongoDbDelegate); + + assertNotNull(provider); + } + } + + @Test + void testConstructorWithPasswordFile() throws IOException { + Path passFile = tempDir.resolve("password.txt"); + Files.write(passFile, "filepass".getBytes()); + + when(mongoDbDelegate.getMongoDbPass()).thenReturn(null); + when(mongoDbDelegate.getMongoDbPassFile()).thenReturn(passFile.toString()); + + try (MockedStatic mockedMongoClients = mockStatic(MongoClients.class)) { + mockedMongoClients + .when(() -> MongoClients.create((String) any())) + .thenReturn(mongoClient); + when(mongoClient.startSession()).thenReturn(clientSession); + + provider = new MongoPersistenceProvider(mongoDbDelegate); + + assertNotNull(provider); + } + } + + @Test + void testConstructorConnectionException() { + try (MockedStatic mockedMongoClients = mockStatic(MongoClients.class)) { + mockedMongoClients + .when(() -> MongoClients.create((String) any())) + .thenReturn(mongoClient); + when(mongoClient.startSession()).thenThrow(new RuntimeException("Connection failed")); + + assertThrows( + RuntimeException.class, + () -> provider = new MongoPersistenceProvider(mongoDbDelegate)); + } + } + + @Test + void testInsertBulkScan() { + try (MockedStatic mockedMongoClients = mockStatic(MongoClients.class); + MockedStatic mockedJackson = + mockStatic(JacksonMongoCollection.class)) { + + setupMocks(mockedMongoClients, mockedJackson); + + provider = new MongoPersistenceProvider(mongoDbDelegate); + + BulkScan bulkScan = createTestBulkScan("test-id", "test-scan"); + + provider.insertBulkScan(bulkScan); + + verify(jacksonBulkScanCollection).insertOne(bulkScan); + } + } + + @Test + void testInsertBulkScanNull() { + try (MockedStatic mockedMongoClients = mockStatic(MongoClients.class); + MockedStatic mockedJackson = + mockStatic(JacksonMongoCollection.class)) { + + setupMocks(mockedMongoClients, mockedJackson); + + provider = new MongoPersistenceProvider(mongoDbDelegate); + + assertThrows(NullPointerException.class, () -> provider.insertBulkScan(null)); + } + } + + @Test + void testUpdateBulkScan() { + try (MockedStatic mockedMongoClients = mockStatic(MongoClients.class); + MockedStatic mockedJackson = + mockStatic(JacksonMongoCollection.class)) { + + setupMocks(mockedMongoClients, mockedJackson); + + provider = new MongoPersistenceProvider(mongoDbDelegate); + + BulkScan bulkScan = createTestBulkScan("test-id", "test-scan"); + + provider.updateBulkScan(bulkScan); + + verify(jacksonBulkScanCollection).removeById("test-id"); + verify(jacksonBulkScanCollection).insertOne(bulkScan); + } + } + + @Test + void testInsertScanResult() { + try (MockedStatic mockedMongoClients = mockStatic(MongoClients.class); + MockedStatic mockedJackson = + mockStatic(JacksonMongoCollection.class)) { + + setupMocks(mockedMongoClients, mockedJackson); + + provider = new MongoPersistenceProvider(mongoDbDelegate); + + ScanTarget scanTarget = new ScanTarget(); + scanTarget.setHostname("example.com"); + + BulkScan bulkScan = createTestBulkScan("bulk-id", "test-db"); + ScanJobDescription scanJob = + new ScanJobDescription(scanTarget, bulkScan, JobStatus.SUCCESS); + + org.bson.Document resultDoc = new org.bson.Document(); + ScanResult scanResult = new ScanResult(scanJob, resultDoc); + + provider.insertScanResult(scanResult, scanJob); + + verify(jacksonScanResultCollection).insertOne(scanResult); + } + } + + @Test + void testInsertScanResultStatusMismatch() { + try (MockedStatic mockedMongoClients = mockStatic(MongoClients.class); + MockedStatic mockedJackson = + mockStatic(JacksonMongoCollection.class)) { + + setupMocks(mockedMongoClients, mockedJackson); + + provider = new MongoPersistenceProvider(mongoDbDelegate); + + ScanTarget scanTarget = new ScanTarget(); + scanTarget.setHostname("example.com"); + + BulkScan bulkScan = createTestBulkScan("bulk-id", "test-db"); + ScanJobDescription scanJob = + new ScanJobDescription(scanTarget, bulkScan, JobStatus.SUCCESS); + + // Create a result with different status + org.bson.Document resultDoc = new org.bson.Document(); + ScanJobDescription wrongStatusJob = + new ScanJobDescription(scanTarget, bulkScan, JobStatus.ERROR); + ScanResult scanResult = new ScanResult(wrongStatusJob, resultDoc); + + assertThrows( + IllegalArgumentException.class, + () -> provider.insertScanResult(scanResult, scanJob)); + } + } + + @Test + void testInsertScanResultWithException() { + try (MockedStatic mockedMongoClients = mockStatic(MongoClients.class); + MockedStatic mockedJackson = + mockStatic(JacksonMongoCollection.class)) { + + setupMocks(mockedMongoClients, mockedJackson); + + provider = new MongoPersistenceProvider(mongoDbDelegate); + + ScanTarget scanTarget = new ScanTarget(); + scanTarget.setHostname("example.com"); + + BulkScan bulkScan = createTestBulkScan("bulk-id", "test-db"); + ScanJobDescription scanJob = + new ScanJobDescription(scanTarget, bulkScan, JobStatus.SUCCESS); + + org.bson.Document resultDoc = new org.bson.Document(); + ScanResult scanResult = new ScanResult(scanJob, resultDoc); + + // First insertion throws exception + doThrow(new RuntimeException("Serialization error")) + .doNothing() + .when(jacksonScanResultCollection) + .insertOne(any()); + + provider.insertScanResult(scanResult, scanJob); + + // Should insert twice - once with original result (fails), once with error result + verify(jacksonScanResultCollection, times(2)).insertOne(any()); + assertEquals(JobStatus.SERIALIZATION_ERROR, scanJob.getStatus()); + } + } + + @Test + void testInsertScanResultWithSerializationErrorRecursion() { + try (MockedStatic mockedMongoClients = mockStatic(MongoClients.class); + MockedStatic mockedJackson = + mockStatic(JacksonMongoCollection.class)) { + + setupMocks(mockedMongoClients, mockedJackson); + + provider = new MongoPersistenceProvider(mongoDbDelegate); + + ScanTarget scanTarget = new ScanTarget(); + scanTarget.setHostname("example.com"); + + BulkScan bulkScan = createTestBulkScan("bulk-id", "test-db"); + ScanJobDescription scanJob = + new ScanJobDescription(scanTarget, bulkScan, JobStatus.SERIALIZATION_ERROR); + + org.bson.Document resultDoc = new org.bson.Document(); + ScanResult scanResult = new ScanResult(scanJob, resultDoc); + + // Always throw exception + doThrow(new RuntimeException("Serialization error")) + .when(jacksonScanResultCollection) + .insertOne(any()); + + provider.insertScanResult(scanResult, scanJob); + + // Should only try once to avoid infinite recursion + verify(jacksonScanResultCollection, times(1)).insertOne(any()); + assertEquals(JobStatus.INTERNAL_ERROR, scanJob.getStatus()); + } + } + + @Test + void testInitDatabaseCaching() throws Exception { + try (MockedStatic mockedMongoClients = mockStatic(MongoClients.class); + MockedStatic mockedJackson = + mockStatic(JacksonMongoCollection.class)) { + + setupMocks(mockedMongoClients, mockedJackson); + + provider = new MongoPersistenceProvider(mongoDbDelegate); + + BulkScan bulkScan1 = createTestBulkScan("id1", "test-db"); + BulkScan bulkScan2 = createTestBulkScan("id2", "test-db"); // Same DB name + + provider.insertBulkScan(bulkScan1); + provider.insertBulkScan(bulkScan2); + + // Database should only be initialized once due to caching + verify(mongoClient, times(1)).getDatabase("test-db"); + } + } + + @Test + void testResultCollectionIndexCreation() { + try (MockedStatic mockedMongoClients = mockStatic(MongoClients.class); + MockedStatic mockedJackson = + mockStatic(JacksonMongoCollection.class)) { + + setupMocks(mockedMongoClients, mockedJackson); + + provider = new MongoPersistenceProvider(mongoDbDelegate); + + ScanTarget scanTarget = new ScanTarget(); + scanTarget.setHostname("example.com"); + + BulkScan bulkScan = createTestBulkScan("bulk-id", "test-db"); + ScanJobDescription scanJob = + new ScanJobDescription(scanTarget, bulkScan, JobStatus.SUCCESS); + + org.bson.Document resultDoc = new org.bson.Document(); + ScanResult scanResult = new ScanResult(scanJob, resultDoc); + + provider.insertScanResult(scanResult, scanJob); + + // Verify indexes are created + verify(jacksonScanResultCollection, times(4)).createIndex(any(Bson.class)); + } + } + + @Test + void testWithRegisteredSerializersAndModules() { + // Register serializers and modules before initialization + MongoPersistenceProvider.registerSerializer(mockSerializer); + MongoPersistenceProvider.registerModule(mockModule); + + try (MockedStatic mockedMongoClients = mockStatic(MongoClients.class)) { + mockedMongoClients + .when(() -> MongoClients.create((String) any())) + .thenReturn(mongoClient); + when(mongoClient.startSession()).thenReturn(clientSession); + + provider = new MongoPersistenceProvider(mongoDbDelegate); + + assertNotNull(provider); + } + } + + private void setupMocks( + MockedStatic mockedMongoClients, + MockedStatic mockedJackson) { + mockedMongoClients.when(() -> MongoClients.create((String) any())).thenReturn(mongoClient); + when(mongoClient.startSession()).thenReturn(clientSession); + when(mongoClient.getDatabase(anyString())).thenReturn(mongoDatabase); + + JacksonMongoCollection.JacksonMongoCollectionBuilder resultBuilder = + mock(JacksonMongoCollection.JacksonMongoCollectionBuilder.class); + JacksonMongoCollection.JacksonMongoCollectionBuilder bulkScanBuilder = + mock(JacksonMongoCollection.JacksonMongoCollectionBuilder.class); + + mockedJackson + .when(JacksonMongoCollection::builder) + .thenReturn(resultBuilder, bulkScanBuilder); + + when(resultBuilder.withObjectMapper(any())).thenReturn(resultBuilder); + when(resultBuilder.build( + any(MongoDatabase.class), anyString(), eq(ScanResult.class), any())) + .thenReturn(jacksonScanResultCollection); + + when(bulkScanBuilder.withObjectMapper(any())).thenReturn(bulkScanBuilder); + when(bulkScanBuilder.build( + any(MongoDatabase.class), anyString(), eq(BulkScan.class), any())) + .thenReturn(jacksonBulkScanCollection); + + when(jacksonScanResultCollection.createIndex(any(Bson.class))).thenReturn("index"); + } +} diff --git a/src/test/java/de/rub/nds/crawler/targetlist/CruxListProviderTest.java b/src/test/java/de/rub/nds/crawler/targetlist/CruxListProviderTest.java new file mode 100644 index 0000000..0657b4a --- /dev/null +++ b/src/test/java/de/rub/nds/crawler/targetlist/CruxListProviderTest.java @@ -0,0 +1,159 @@ +/* + * TLS-Crawler - A TLS scanning tool to perform large scale scans with the TLS-Scanner + * + * Copyright 2018-2022 Ruhr University Bochum, Paderborn University, and Hackmanit GmbH + * + * Licensed under Apache License, Version 2.0 + * http://www.apache.org/licenses/LICENSE-2.0.txt + */ +package de.rub.nds.crawler.targetlist; + +import static org.junit.jupiter.api.Assertions.*; + +import de.rub.nds.crawler.constant.CruxListNumber; +import java.util.List; +import java.util.stream.Stream; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +class CruxListProviderTest { + + private CruxListProvider provider; + + @BeforeEach + void setUp() { + provider = new CruxListProvider(CruxListNumber.TOP_1k); + } + + @Test + void testGetTargetListFromLines() { + // Test data with various formats + List lines = + List.of( + "https://google.com,1", + "https://facebook.com,2", + "http://example.com,3", // Should be filtered out (not https) + "https://amazon.com,999", + "https://twitter.com,1000", + "https://youtube.com,1001", // Should be filtered out (rank > 1000) + "https://netflix.com,500"); + + List result = provider.getTargetListFromLines(lines.stream()); + + assertEquals(5, result.size()); + assertTrue(result.contains("google.com")); + assertTrue(result.contains("facebook.com")); + assertTrue(result.contains("amazon.com")); + assertTrue(result.contains("twitter.com")); + assertTrue(result.contains("netflix.com")); + assertFalse(result.contains("example.com")); // http filtered out + assertFalse(result.contains("youtube.com")); // rank too high + } + + @Test + void testGetTargetListFromLinesTop5K() { + CruxListProvider provider5k = new CruxListProvider(CruxListNumber.TOP_5K); + + List lines = + List.of( + "https://site1.com,1", + "https://site2.com,4999", + "https://site3.com,5000", + "https://site4.com,5001"); // Should be filtered out + + List result = provider5k.getTargetListFromLines(lines.stream()); + + assertEquals(3, result.size()); + assertTrue(result.contains("site1.com")); + assertTrue(result.contains("site2.com")); + assertTrue(result.contains("site3.com")); + assertFalse(result.contains("site4.com")); + } + + @Test + void testGetTargetListFromLinesWithSubdomains() { + List lines = + List.of( + "https://www.google.com,1", + "https://mail.google.com,2", + "https://subdomain.example.com,100"); + + List result = provider.getTargetListFromLines(lines.stream()); + + assertEquals(3, result.size()); + assertEquals("www.google.com", result.get(0)); + assertEquals("mail.google.com", result.get(1)); + assertEquals("subdomain.example.com", result.get(2)); + } + + @Test + void testGetTargetListFromLinesWithPorts() { + List lines = List.of("https://example.com:8443,1", "https://test.com:443,2"); + + List result = provider.getTargetListFromLines(lines.stream()); + + assertEquals(2, result.size()); + assertEquals("example.com:8443", result.get(0)); + assertEquals("test.com:443", result.get(1)); + } + + @Test + void testGetTargetListFromLinesWithPaths() { + List lines = + List.of("https://example.com/path,1", "https://test.com/path/to/resource,2"); + + List result = provider.getTargetListFromLines(lines.stream()); + + assertEquals(2, result.size()); + assertEquals("example.com/path", result.get(0)); + assertEquals("test.com/path/to/resource", result.get(1)); + } + + @Test + void testGetTargetListFromLinesEmptyInput() { + List result = provider.getTargetListFromLines(Stream.empty()); + assertTrue(result.isEmpty()); + } + + @Test + void testGetTargetListFromLinesInvalidFormat() { + List lines = + List.of( + "invalid-line-without-comma", + "https://valid.com,1", + "another-invalid-line"); + + // Should handle gracefully or throw exception + assertThrows(Exception.class, () -> provider.getTargetListFromLines(lines.stream())); + } + + @Test + void testGetTargetListFromLinesInvalidRankFormat() { + List lines = List.of("https://example.com,not-a-number", "https://valid.com,1"); + + // Should throw NumberFormatException + assertThrows( + NumberFormatException.class, () -> provider.getTargetListFromLines(lines.stream())); + } + + @Test + void testAllCruxListNumbers() { + // Test each enum value + for (CruxListNumber cruxNumber : CruxListNumber.values()) { + CruxListProvider testProvider = new CruxListProvider(cruxNumber); + + List lines = + List.of( + "https://site1.com,1", + String.format("https://site2.com,%d", cruxNumber.getNumber()), + String.format("https://site3.com,%d", cruxNumber.getNumber() + 1)); + + List result = testProvider.getTargetListFromLines(lines.stream()); + + assertEquals(2, result.size()); + assertTrue(result.contains("site1.com")); + assertTrue(result.contains("site2.com")); + assertFalse(result.contains("site3.com")); + } + } +} diff --git a/src/test/java/de/rub/nds/crawler/targetlist/TrancoEmailListProviderTest.java b/src/test/java/de/rub/nds/crawler/targetlist/TrancoEmailListProviderTest.java new file mode 100644 index 0000000..09e3648 --- /dev/null +++ b/src/test/java/de/rub/nds/crawler/targetlist/TrancoEmailListProviderTest.java @@ -0,0 +1,258 @@ +/* + * TLS-Crawler - A TLS scanning tool to perform large scale scans with the TLS-Scanner + * + * Copyright 2018-2022 Ruhr University Bochum, Paderborn University, and Hackmanit GmbH + * + * Licensed under Apache License, Version 2.0 + * http://www.apache.org/licenses/LICENSE-2.0.txt + */ +package de.rub.nds.crawler.targetlist; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; + +import java.util.List; +import javax.naming.NamingException; +import javax.naming.directory.Attribute; +import javax.naming.directory.Attributes; +import javax.naming.directory.BasicAttribute; +import javax.naming.directory.BasicAttributes; +import javax.naming.directory.InitialDirContext; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockedConstruction; +import org.mockito.MockitoAnnotations; + +class TrancoEmailListProviderTest { + + @Mock private ITargetListProvider trancoList; + + private TrancoEmailListProvider provider; + + @BeforeEach + void setUp() { + MockitoAnnotations.openMocks(this); + provider = new TrancoEmailListProvider(trancoList); + } + + @Test + void testGetTargetList() throws NamingException { + // Setup mock tranco list + when(trancoList.getTargetList()).thenReturn(List.of("1,google.com", "2,facebook.com")); + + // Mock InitialDirContext + try (MockedConstruction mockedDirContext = + mockConstruction(InitialDirContext.class)) { + + InitialDirContext mockDirContext = mockedDirContext.constructed().get(0); + + // Setup MX records for google.com + Attributes googleAttributes = new BasicAttributes(); + Attribute googleMX = new BasicAttribute("MX"); + googleMX.add("10 aspmx.l.google.com."); + googleMX.add("20 alt1.aspmx.l.google.com."); + googleAttributes.put(googleMX); + when(mockDirContext.getAttributes("dns:/google.com", new String[] {"MX"})) + .thenReturn(googleAttributes); + + // Setup MX records for facebook.com + Attributes facebookAttributes = new BasicAttributes(); + Attribute facebookMX = new BasicAttribute("MX"); + facebookMX.add("10 smtpin.vvv.facebook.com."); + facebookAttributes.put(facebookMX); + when(mockDirContext.getAttributes("dns:/facebook.com", new String[] {"MX"})) + .thenReturn(facebookAttributes); + + // Execute + List result = provider.getTargetList(); + + // Verify + assertEquals(3, result.size()); + assertTrue(result.contains("aspmx.l.google.com.")); + assertTrue(result.contains("alt1.aspmx.l.google.com.")); + assertTrue(result.contains("smtpin.vvv.facebook.com.")); + } + } + + @Test + void testGetTargetListWithNoMXRecord() throws NamingException { + // Setup mock tranco list + when(trancoList.getTargetList()).thenReturn(List.of("1,example.com")); + + // Mock InitialDirContext + try (MockedConstruction mockedDirContext = + mockConstruction(InitialDirContext.class)) { + + InitialDirContext mockDirContext = mockedDirContext.constructed().get(0); + + // Setup no MX records + Attributes emptyAttributes = new BasicAttributes(); + when(mockDirContext.getAttributes("dns:/example.com", new String[] {"MX"})) + .thenReturn(emptyAttributes); + + // Execute + List result = provider.getTargetList(); + + // Verify + assertTrue(result.isEmpty()); + } + } + + @Test + void testGetTargetListWithNamingException() throws NamingException { + // Setup mock tranco list + when(trancoList.getTargetList()).thenReturn(List.of("1,badhost.com")); + + // Mock InitialDirContext + try (MockedConstruction mockedDirContext = + mockConstruction(InitialDirContext.class)) { + + InitialDirContext mockDirContext = mockedDirContext.constructed().get(0); + + // Throw NamingException + when(mockDirContext.getAttributes(anyString(), any(String[].class))) + .thenThrow(new NamingException("DNS lookup failed")); + + // Execute + List result = provider.getTargetList(); + + // Verify - should handle exception gracefully + assertTrue(result.isEmpty()); + } + } + + @Test + void testGetTargetListWithDuplicateMXRecords() throws NamingException { + // Setup mock tranco list with multiple domains + when(trancoList.getTargetList()) + .thenReturn(List.of("1,site1.com", "2,site2.com", "3,site3.com")); + + // Mock InitialDirContext + try (MockedConstruction mockedDirContext = + mockConstruction(InitialDirContext.class)) { + + InitialDirContext mockDirContext = mockedDirContext.constructed().get(0); + + // All sites use the same MX server + Attributes sameAttributes = new BasicAttributes(); + Attribute sameMX = new BasicAttribute("MX"); + sameMX.add("10 mail.shared.com."); + sameAttributes.put(sameMX); + + when(mockDirContext.getAttributes(anyString(), any(String[].class))) + .thenReturn(sameAttributes); + + // Execute + List result = provider.getTargetList(); + + // Verify - duplicates should be removed + assertEquals(1, result.size()); + assertEquals("mail.shared.com.", result.get(0)); + } + } + + @Test + void testGetTargetListInitialDirContextCreationFails() throws NamingException { + // Setup mock tranco list + when(trancoList.getTargetList()).thenReturn(List.of("1,google.com")); + + // Mock InitialDirContext constructor to throw + try (MockedConstruction mockedDirContext = + mockConstruction( + InitialDirContext.class, + (mock, context) -> { + throw new NamingException("Cannot create context"); + })) { + + // Execute + List result = provider.getTargetList(); + + // Verify - should handle exception gracefully + assertTrue(result.isEmpty()); + } + } + + @Test + void testGetTargetListWithVariousHostFormats() throws NamingException { + // Setup mock tranco list with various formats + when(trancoList.getTargetList()) + .thenReturn( + List.of( + "example.com", // Without rank + "1,site.com", // With rank + "2,subdomain.site.com" // With subdomain + )); + + // Mock InitialDirContext + try (MockedConstruction mockedDirContext = + mockConstruction(InitialDirContext.class)) { + + InitialDirContext mockDirContext = mockedDirContext.constructed().get(0); + + // Setup different MX records + Attributes attr1 = new BasicAttributes(); + Attribute mx1 = new BasicAttribute("MX"); + mx1.add("10 mail1.example.com."); + attr1.put(mx1); + when(mockDirContext.getAttributes("dns:/example.com", new String[] {"MX"})) + .thenReturn(attr1); + + Attributes attr2 = new BasicAttributes(); + Attribute mx2 = new BasicAttribute("MX"); + mx2.add("10 mail2.site.com."); + attr2.put(mx2); + when(mockDirContext.getAttributes("dns:/site.com", new String[] {"MX"})) + .thenReturn(attr2); + + Attributes attr3 = new BasicAttributes(); + Attribute mx3 = new BasicAttribute("MX"); + mx3.add("10 mail3.subdomain.site.com."); + attr3.put(mx3); + when(mockDirContext.getAttributes("dns:/subdomain.site.com", new String[] {"MX"})) + .thenReturn(attr3); + + // Execute + List result = provider.getTargetList(); + + // Verify + assertEquals(3, result.size()); + assertTrue(result.contains("mail1.example.com.")); + assertTrue(result.contains("mail2.site.com.")); + assertTrue(result.contains("mail3.subdomain.site.com.")); + } + } + + @Test + void testGetTargetListWithMultiplePriorityMX() throws NamingException { + // Setup mock tranco list + when(trancoList.getTargetList()).thenReturn(List.of("1,example.com")); + + // Mock InitialDirContext + try (MockedConstruction mockedDirContext = + mockConstruction(InitialDirContext.class)) { + + InitialDirContext mockDirContext = mockedDirContext.constructed().get(0); + + // Setup MX records with different priorities + Attributes attributes = new BasicAttributes(); + Attribute mxRecords = new BasicAttribute("MX"); + mxRecords.add("10 primary.example.com."); + mxRecords.add("20 secondary.example.com."); + mxRecords.add("30 tertiary.example.com."); + attributes.put(mxRecords); + when(mockDirContext.getAttributes("dns:/example.com", new String[] {"MX"})) + .thenReturn(attributes); + + // Execute + List result = provider.getTargetList(); + + // Verify - all MX records should be included + assertEquals(3, result.size()); + assertTrue(result.contains("primary.example.com.")); + assertTrue(result.contains("secondary.example.com.")); + assertTrue(result.contains("tertiary.example.com.")); + } + } +} diff --git a/src/test/java/de/rub/nds/crawler/targetlist/TrancoListProviderTest.java b/src/test/java/de/rub/nds/crawler/targetlist/TrancoListProviderTest.java new file mode 100644 index 0000000..1c66a9e --- /dev/null +++ b/src/test/java/de/rub/nds/crawler/targetlist/TrancoListProviderTest.java @@ -0,0 +1,119 @@ +/* + * TLS-Crawler - A TLS scanning tool to perform large scale scans with the TLS-Scanner + * + * Copyright 2018-2022 Ruhr University Bochum, Paderborn University, and Hackmanit GmbH + * + * Licensed under Apache License, Version 2.0 + * http://www.apache.org/licenses/LICENSE-2.0.txt + */ +package de.rub.nds.crawler.targetlist; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; + +import java.io.*; +import java.net.URL; +import java.nio.channels.Channels; +import java.nio.channels.FileChannel; +import java.nio.channels.ReadableByteChannel; +import java.nio.file.Path; +import java.util.List; +import java.util.zip.ZipEntry; +import java.util.zip.ZipOutputStream; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import org.mockito.MockedConstruction; +import org.mockito.MockedStatic; + +class TrancoListProviderTest { + + @TempDir Path tempDir; + + private TrancoListProvider provider; + + @BeforeEach + void setUp() { + provider = new TrancoListProvider(3); + } + + @Test + void testGetTargetList() throws Exception { + // Create test data + String csvContent = + "1,google.com\n2,facebook.com\n3,amazon.com\n4,twitter.com\n5,youtube.com"; + + // Create a zip file with the CSV content + Path zipFile = tempDir.resolve("test.zip"); + try (ZipOutputStream zos = new ZipOutputStream(new FileOutputStream(zipFile.toFile()))) { + ZipEntry entry = new ZipEntry("tranco-1m.csv"); + zos.putNextEntry(entry); + zos.write(csvContent.getBytes()); + zos.closeEntry(); + } + + // Mock URL and file operations + try (MockedStatic mockedChannels = mockStatic(Channels.class); + MockedConstruction mockedURL = mockConstruction(URL.class); + MockedConstruction mockedFOS = + mockConstruction(FileOutputStream.class)) { + + // Setup mocks + URL mockUrl = mockedURL.constructed().get(0); + InputStream mockInputStream = new FileInputStream(zipFile.toFile()); + when(mockUrl.openStream()).thenReturn(mockInputStream); + + ReadableByteChannel mockChannel = mock(ReadableByteChannel.class); + mockedChannels + .when(() -> Channels.newChannel(any(InputStream.class))) + .thenReturn(mockChannel); + + FileOutputStream mockFos = mockedFOS.constructed().get(0); + FileChannel mockFileChannel = mock(FileChannel.class); + when(mockFos.getChannel()).thenReturn(mockFileChannel); + when(mockFileChannel.transferFrom(any(), anyLong(), anyLong())).thenReturn(0L); + + // Execute + List targets = provider.getTargetList(); + + // Verify + assertEquals(3, targets.size()); + assertEquals("1,google.com", targets.get(0)); + assertEquals("2,facebook.com", targets.get(1)); + assertEquals("3,amazon.com", targets.get(2)); + } + } + + @Test + void testGetTargetListFromLines() { + // Test the abstract method implementation + List lines = + List.of( + "1,google.com", + "2,facebook.com", + "3,amazon.com", + "4,twitter.com", + "5,youtube.com"); + + List result = provider.getTargetListFromLines(lines.stream()); + + assertEquals(3, result.size()); + assertEquals("1,google.com", result.get(0)); + assertEquals("2,facebook.com", result.get(1)); + assertEquals("3,amazon.com", result.get(2)); + } + + @Test + void testGetTargetListFromLinesFewerThanRequested() { + // Test when there are fewer lines than requested + List lines = List.of("1,google.com", "2,facebook.com"); + + TrancoListProvider provider = new TrancoListProvider(5); + List result = provider.getTargetListFromLines(lines.stream()); + + assertEquals(2, result.size()); + assertEquals("1,google.com", result.get(0)); + assertEquals("2,facebook.com", result.get(1)); + } +} diff --git a/src/test/java/de/rub/nds/crawler/targetlist/ZipFileProviderTest.java b/src/test/java/de/rub/nds/crawler/targetlist/ZipFileProviderTest.java new file mode 100644 index 0000000..f0ce86f --- /dev/null +++ b/src/test/java/de/rub/nds/crawler/targetlist/ZipFileProviderTest.java @@ -0,0 +1,299 @@ +/* + * TLS-Crawler - A TLS scanning tool to perform large scale scans with the TLS-Scanner + * + * Copyright 2018-2023 Ruhr University Bochum, Paderborn University, and Hackmanit GmbH + * + * Licensed under Apache License, Version 2.0 + * http://www.apache.org/licenses/LICENSE-2.0.txt + */ +package de.rub.nds.crawler.targetlist; + +import static org.junit.jupiter.api.Assertions.*; +import static org.mockito.ArgumentMatchers.*; +import static org.mockito.Mockito.*; + +import java.io.*; +import java.net.URL; +import java.nio.channels.Channels; +import java.nio.channels.ReadableByteChannel; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.List; +import java.util.stream.Stream; +import java.util.zip.GZIPOutputStream; +import java.util.zip.ZipEntry; +import java.util.zip.ZipOutputStream; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; +import org.mockito.MockedConstruction; +import org.mockito.MockedStatic; + +class ZipFileProviderTest { + + @TempDir Path tempDir; + + private TestZipFileProvider provider; + + private static class TestZipFileProvider extends ZipFileProvider { + private final List mockTargets; + + public TestZipFileProvider( + int number, + String sourceUrl, + String zipFilename, + String outputFile, + String listName, + List mockTargets) { + super(number, sourceUrl, zipFilename, outputFile, listName); + this.mockTargets = mockTargets; + } + + @Override + protected List getTargetListFromLines(Stream lines) { + return mockTargets != null ? mockTargets : lines.limit(number).toList(); + } + } + + @BeforeEach + void setUp() { + provider = + new TestZipFileProvider( + 5, + "http://example.com/test.zip", + tempDir.resolve("test.zip").toString(), + tempDir.resolve("test.csv").toString(), + "TestList", + null); + } + + @Test + void testGetTargetListWithZipFile() throws Exception { + // Create test data + String csvContent = "host1.com\nhost2.com\nhost3.com\nhost4.com\nhost5.com\nhost6.com"; + + // Create a zip file with the CSV content + Path realZipFile = tempDir.resolve("real.zip"); + try (ZipOutputStream zos = + new ZipOutputStream(new FileOutputStream(realZipFile.toFile()))) { + ZipEntry entry = new ZipEntry("test.csv"); + zos.putNextEntry(entry); + zos.write(csvContent.getBytes()); + zos.closeEntry(); + } + + // Mock URL download + try (MockedConstruction mockedURL = mockConstruction(URL.class); + MockedStatic mockedChannels = mockStatic(Channels.class); + MockedConstruction mockedFOS = + mockConstruction(FileOutputStream.class)) { + + // Setup URL mock + URL mockUrl = mockedURL.constructed().get(0); + InputStream mockInputStream = new FileInputStream(realZipFile.toFile()); + when(mockUrl.openStream()).thenReturn(mockInputStream); + + // Setup channel mock + ReadableByteChannel mockChannel = mock(ReadableByteChannel.class); + mockedChannels + .when(() -> Channels.newChannel(any(InputStream.class))) + .thenReturn(mockChannel); + + // Setup FileOutputStream mock to actually write the file + FileOutputStream realFos = new FileOutputStream(tempDir.resolve("test.zip").toFile()); + Files.copy(realZipFile, realFos); + realFos.close(); + + // Execute + List targets = provider.getTargetList(); + + // Verify + assertEquals(5, targets.size()); + assertEquals("host1.com", targets.get(0)); + assertEquals("host5.com", targets.get(4)); + + // Verify files are deleted + assertFalse(Files.exists(tempDir.resolve("test.zip"))); + assertFalse(Files.exists(tempDir.resolve("test.csv"))); + } + } + + @Test + void testGetTargetListWithGzipFile() throws Exception { + // Create test data + String csvContent = "host1.com\nhost2.com\nhost3.com"; + + // Create a gzip file with the CSV content + Path gzipFile = tempDir.resolve("test.csv.gz"); + try (GZIPOutputStream gos = new GZIPOutputStream(new FileOutputStream(gzipFile.toFile()))) { + gos.write(csvContent.getBytes()); + } + + // Create provider for gzip file + TestZipFileProvider gzipProvider = + new TestZipFileProvider( + 3, + "http://example.com/test.csv.gz", + gzipFile.toString(), + tempDir.resolve("test.csv").toString(), + "TestList", + null); + + // Mock URL download + try (MockedConstruction mockedURL = mockConstruction(URL.class); + MockedStatic mockedChannels = mockStatic(Channels.class)) { + + // Setup URL mock + URL mockUrl = mockedURL.constructed().get(0); + InputStream mockInputStream = new FileInputStream(gzipFile.toFile()); + when(mockUrl.openStream()).thenReturn(mockInputStream); + + // Setup channel mock to do nothing (file already exists) + ReadableByteChannel mockChannel = mock(ReadableByteChannel.class); + mockedChannels + .when(() -> Channels.newChannel(any(InputStream.class))) + .thenReturn(mockChannel); + + // Execute + List targets = gzipProvider.getTargetList(); + + // Verify + assertEquals(3, targets.size()); + assertEquals("host1.com", targets.get(0)); + assertEquals("host3.com", targets.get(2)); + } + } + + @Test + void testGetTargetListDownloadError() throws Exception { + // Mock URL download to fail + try (MockedConstruction mockedURL = mockConstruction(URL.class)) { + URL mockUrl = mockedURL.constructed().get(0); + when(mockUrl.openStream()).thenThrow(new IOException("Download failed")); + + // Create empty zip file so unzip phase can proceed + Path zipFile = tempDir.resolve("test.zip"); + try (ZipOutputStream zos = + new ZipOutputStream(new FileOutputStream(zipFile.toFile()))) { + ZipEntry entry = new ZipEntry("test.csv"); + zos.putNextEntry(entry); + zos.write("host1.com".getBytes()); + zos.closeEntry(); + } + + // Execute - should not throw + List targets = provider.getTargetList(); + + // Should still return results from existing file + assertNotNull(targets); + } + } + + @Test + void testGetTargetListUnzipError() throws Exception { + // Create invalid zip file + Path invalidZipFile = tempDir.resolve("test.zip"); + Files.write(invalidZipFile, "This is not a zip file".getBytes()); + + // Mock URL download + try (MockedConstruction mockedURL = mockConstruction(URL.class); + MockedStatic mockedChannels = mockStatic(Channels.class)) { + + URL mockUrl = mockedURL.constructed().get(0); + InputStream mockInputStream = new ByteArrayInputStream("dummy".getBytes()); + when(mockUrl.openStream()).thenReturn(mockInputStream); + + ReadableByteChannel mockChannel = mock(ReadableByteChannel.class); + mockedChannels + .when(() -> Channels.newChannel(any(InputStream.class))) + .thenReturn(mockChannel); + + // Execute - should throw RuntimeException + assertThrows(RuntimeException.class, () -> provider.getTargetList()); + } + } + + @Test + void testGetTargetListReadFileError() throws Exception { + // Create provider with non-existent output file + TestZipFileProvider errorProvider = + new TestZipFileProvider( + 5, + "http://example.com/test.zip", + tempDir.resolve("test.zip").toString(), + "/non/existent/path/test.csv", + "TestList", + null); + + // Create valid zip file + Path zipFile = tempDir.resolve("test.zip"); + try (ZipOutputStream zos = new ZipOutputStream(new FileOutputStream(zipFile.toFile()))) { + ZipEntry entry = new ZipEntry("test.csv"); + zos.putNextEntry(entry); + zos.write("host1.com".getBytes()); + zos.closeEntry(); + } + + // Mock URL download + try (MockedConstruction mockedURL = mockConstruction(URL.class); + MockedStatic mockedChannels = mockStatic(Channels.class)) { + + URL mockUrl = mockedURL.constructed().get(0); + InputStream mockInputStream = new ByteArrayInputStream("dummy".getBytes()); + when(mockUrl.openStream()).thenReturn(mockInputStream); + + ReadableByteChannel mockChannel = mock(ReadableByteChannel.class); + mockedChannels + .when(() -> Channels.newChannel(any(InputStream.class))) + .thenReturn(mockChannel); + + // Execute - should throw RuntimeException + assertThrows(RuntimeException.class, () -> errorProvider.getTargetList()); + } + } + + @Test + void testDeleteFileErrors() throws Exception { + // Create files that will fail to delete + Path zipFile = tempDir.resolve("test.zip"); + Path csvFile = tempDir.resolve("test.csv"); + + // Create zip file + try (ZipOutputStream zos = new ZipOutputStream(new FileOutputStream(zipFile.toFile()))) { + ZipEntry entry = new ZipEntry("test.csv"); + zos.putNextEntry(entry); + zos.write("host1.com".getBytes()); + zos.closeEntry(); + } + + // Mock file deletion to fail + try (MockedStatic mockedFiles = mockStatic(Files.class, CALLS_REAL_METHODS); + MockedConstruction mockedURL = mockConstruction(URL.class); + MockedStatic mockedChannels = mockStatic(Channels.class)) { + + // Mock file operations + mockedFiles + .when(() -> Files.delete(zipFile)) + .thenThrow(new IOException("Cannot delete zip")); + mockedFiles + .when(() -> Files.delete(csvFile)) + .thenThrow(new IOException("Cannot delete csv")); + + // Allow reading + mockedFiles.when(() -> Files.lines(csvFile)).thenReturn(Stream.of("host1.com")); + + URL mockUrl = mockedURL.constructed().get(0); + InputStream mockInputStream = new ByteArrayInputStream("dummy".getBytes()); + when(mockUrl.openStream()).thenReturn(mockInputStream); + + ReadableByteChannel mockChannel = mock(ReadableByteChannel.class); + mockedChannels + .when(() -> Channels.newChannel(any(InputStream.class))) + .thenReturn(mockChannel); + + // Execute - should not throw despite delete errors + List targets = provider.getTargetList(); + assertNotNull(targets); + } + } +} diff --git a/src/test/java/de/rub/nds/crawler/util/CanceallableThreadPoolExecutorTest.java b/src/test/java/de/rub/nds/crawler/util/CanceallableThreadPoolExecutorTest.java new file mode 100644 index 0000000..e98ea6d --- /dev/null +++ b/src/test/java/de/rub/nds/crawler/util/CanceallableThreadPoolExecutorTest.java @@ -0,0 +1,402 @@ +/* + * TLS-Crawler - A TLS scanning tool to perform large scale scans with the TLS-Scanner + * + * Copyright 2018-2022 Ruhr University Bochum, Paderborn University, and Hackmanit GmbH + * + * Licensed under Apache License, Version 2.0 + * http://www.apache.org/licenses/LICENSE-2.0.txt + */ +package de.rub.nds.crawler.util; + +import static org.junit.jupiter.api.Assertions.*; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.*; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +class CanceallableThreadPoolExecutorTest { + + private CanceallableThreadPoolExecutor executor; + + @BeforeEach + void setUp() { + executor = + new CanceallableThreadPoolExecutor( + 2, // corePoolSize + 4, // maximumPoolSize + 60L, // keepAliveTime + TimeUnit.SECONDS, + new LinkedBlockingQueue<>()); + } + + @AfterEach + void tearDown() throws InterruptedException { + executor.shutdownNow(); + assertTrue(executor.awaitTermination(5, TimeUnit.SECONDS)); + } + + @Test + void testConstructorWithThreadFactory() { + ThreadFactory threadFactory = Executors.defaultThreadFactory(); + CanceallableThreadPoolExecutor executorWithFactory = + new CanceallableThreadPoolExecutor( + 2, 4, 60L, TimeUnit.SECONDS, new LinkedBlockingQueue<>(), threadFactory); + + assertNotNull(executorWithFactory); + executorWithFactory.shutdown(); + } + + @Test + void testConstructorWithRejectedExecutionHandler() { + RejectedExecutionHandler handler = new ThreadPoolExecutor.AbortPolicy(); + CanceallableThreadPoolExecutor executorWithHandler = + new CanceallableThreadPoolExecutor( + 2, 4, 60L, TimeUnit.SECONDS, new LinkedBlockingQueue<>(), handler); + + assertNotNull(executorWithHandler); + executorWithHandler.shutdown(); + } + + @Test + void testConstructorWithThreadFactoryAndHandler() { + ThreadFactory threadFactory = Executors.defaultThreadFactory(); + RejectedExecutionHandler handler = new ThreadPoolExecutor.AbortPolicy(); + CanceallableThreadPoolExecutor executorFull = + new CanceallableThreadPoolExecutor( + 2, + 4, + 60L, + TimeUnit.SECONDS, + new LinkedBlockingQueue<>(), + threadFactory, + handler); + + assertNotNull(executorFull); + executorFull.shutdown(); + } + + @Test + void testSubmitCallable() throws Exception { + Callable task = () -> "test result"; + + Future future = executor.submit(task); + + assertEquals("test result", future.get(1, TimeUnit.SECONDS)); + assertInstanceOf(CancellableFuture.class, future); + } + + @Test + void testSubmitRunnable() throws Exception { + AtomicBoolean executed = new AtomicBoolean(false); + Runnable task = () -> executed.set(true); + + Future future = executor.submit(task); + + future.get(1, TimeUnit.SECONDS); + assertTrue(executed.get()); + assertInstanceOf(CancellableFuture.class, future); + } + + @Test + void testSubmitRunnableWithResult() throws Exception { + AtomicBoolean executed = new AtomicBoolean(false); + Runnable task = () -> executed.set(true); + String result = "test result"; + + Future future = executor.submit(task, result); + + assertEquals(result, future.get(1, TimeUnit.SECONDS)); + assertTrue(executed.get()); + assertInstanceOf(CancellableFuture.class, future); + } + + @Test + void testCancellableFutureCreation() throws Exception { + Callable task = () -> "test"; + + Future future = executor.submit(task); + + assertInstanceOf(CancellableFuture.class, future); + } + + @Test + void testExecuteRunnable() throws Exception { + CountDownLatch latch = new CountDownLatch(1); + AtomicBoolean executed = new AtomicBoolean(false); + + executor.execute( + () -> { + executed.set(true); + latch.countDown(); + }); + + assertTrue(latch.await(1, TimeUnit.SECONDS)); + assertTrue(executed.get()); + } + + @Test + void testMultipleTaskSubmission() throws Exception { + int taskCount = 20; + CountDownLatch latch = new CountDownLatch(taskCount); + AtomicInteger completedCount = new AtomicInteger(0); + List> futures = new ArrayList<>(); + + for (int i = 0; i < taskCount; i++) { + final int taskId = i; + Future future = + executor.submit( + () -> { + try { + Thread.sleep(10); + completedCount.incrementAndGet(); + return taskId; + } finally { + latch.countDown(); + } + }); + futures.add(future); + } + + assertTrue(latch.await(5, TimeUnit.SECONDS)); + assertEquals(taskCount, completedCount.get()); + + // Verify all futures completed successfully and are CancellableFutures + for (int i = 0; i < taskCount; i++) { + assertEquals(i, futures.get(i).get()); + assertInstanceOf(CancellableFuture.class, futures.get(i)); + } + } + + @Test + void testTaskRejection() { + // Create executor with limited queue + executor = + new CanceallableThreadPoolExecutor( + 1, // corePoolSize + 1, // maximumPoolSize + 60L, + TimeUnit.SECONDS, + new LinkedBlockingQueue<>(1)); // queue size 1 + + CountDownLatch blockingLatch = new CountDownLatch(1); + + // Submit blocking task to fill the thread + executor.submit( + () -> { + try { + blockingLatch.await(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + return null; + }); + + // Submit task to fill the queue + executor.submit(() -> "queued"); + + // This should be rejected + assertThrows( + RejectedExecutionException.class, + () -> { + executor.submit(() -> "rejected"); + }); + + blockingLatch.countDown(); + } + + @Test + void testShutdown() throws Exception { + AtomicBoolean taskCompleted = new AtomicBoolean(false); + CountDownLatch latch = new CountDownLatch(1); + + Future future = + executor.submit( + () -> { + latch.await(); + taskCompleted.set(true); + return null; + }); + + executor.shutdown(); + assertTrue(executor.isShutdown()); + + // Should not accept new tasks + assertThrows( + RejectedExecutionException.class, + () -> { + executor.submit(() -> "new task"); + }); + + // Let the submitted task complete + latch.countDown(); + future.get(1, TimeUnit.SECONDS); + assertTrue(taskCompleted.get()); + + assertTrue(executor.awaitTermination(1, TimeUnit.SECONDS)); + assertTrue(executor.isTerminated()); + } + + @Test + void testShutdownNow() throws Exception { + CountDownLatch startLatch = new CountDownLatch(1); + AtomicBoolean interrupted = new AtomicBoolean(false); + + Future future = + executor.submit( + () -> { + try { + startLatch.countDown(); + Thread.sleep(5000); + } catch (InterruptedException e) { + interrupted.set(true); + Thread.currentThread().interrupt(); + } + }); + + // Wait for task to start + assertTrue(startLatch.await(1, TimeUnit.SECONDS)); + + List pendingTasks = executor.shutdownNow(); + assertTrue(executor.isShutdown()); + + // The running task should be interrupted + assertThrows( + CancellationException.class, + () -> { + future.get(1, TimeUnit.SECONDS); + }); + + assertTrue(interrupted.get()); + assertTrue(executor.awaitTermination(2, TimeUnit.SECONDS)); + } + + @Test + void testConcurrentTaskExecution() throws Exception { + int corePoolSize = 2; + executor = + new CanceallableThreadPoolExecutor( + corePoolSize, + corePoolSize, + 60L, + TimeUnit.SECONDS, + new LinkedBlockingQueue<>()); + + CountDownLatch startLatch = new CountDownLatch(corePoolSize); + CountDownLatch endLatch = new CountDownLatch(corePoolSize); + AtomicInteger concurrentCount = new AtomicInteger(0); + AtomicInteger maxConcurrent = new AtomicInteger(0); + + for (int i = 0; i < corePoolSize; i++) { + executor.submit( + () -> { + startLatch.countDown(); + int current = concurrentCount.incrementAndGet(); + maxConcurrent.updateAndGet(max -> Math.max(max, current)); + try { + Thread.sleep(100); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + concurrentCount.decrementAndGet(); + endLatch.countDown(); + return null; + }); + } + + assertTrue(startLatch.await(1, TimeUnit.SECONDS)); + assertTrue(endLatch.await(2, TimeUnit.SECONDS)); + + // All tasks should have run concurrently + assertEquals(corePoolSize, maxConcurrent.get()); + } + + @Test + void testCancellableFutureIsCancelledCorrectly() throws Exception { + CountDownLatch latch = new CountDownLatch(1); + AtomicBoolean wasInterrupted = new AtomicBoolean(false); + + Future future = + executor.submit( + () -> { + try { + latch.await(); // Wait indefinitely + return "completed"; + } catch (InterruptedException e) { + wasInterrupted.set(true); + Thread.currentThread().interrupt(); + throw e; + } + }); + + // Cancel the future + assertTrue(future.cancel(true)); + assertTrue(future.isCancelled()); + + // The task should have been interrupted + Thread.sleep(100); // Give it time to process the cancellation + assertTrue(wasInterrupted.get()); + + latch.countDown(); // Clean up + } + + @Test + void testInvokeAll() throws Exception { + List> tasks = List.of(() -> "task1", () -> "task2", () -> "task3"); + + List> futures = executor.invokeAll(tasks); + + assertEquals(3, futures.size()); + for (int i = 0; i < futures.size(); i++) { + assertEquals("task" + (i + 1), futures.get(i).get()); + assertInstanceOf(CancellableFuture.class, futures.get(i)); + } + } + + @Test + void testInvokeAny() throws Exception { + List> tasks = + List.of( + () -> { + Thread.sleep(100); + return "slow"; + }, + () -> "fast", + () -> { + Thread.sleep(200); + return "slower"; + }); + + String result = executor.invokeAny(tasks); + + assertEquals("fast", result); + } + + @Test + void testCorePoolSize() { + assertEquals(2, executor.getCorePoolSize()); + + executor.setCorePoolSize(3); + assertEquals(3, executor.getCorePoolSize()); + } + + @Test + void testMaximumPoolSize() { + assertEquals(4, executor.getMaximumPoolSize()); + + executor.setMaximumPoolSize(5); + assertEquals(5, executor.getMaximumPoolSize()); + } + + @Test + void testKeepAliveTime() { + assertEquals(60L, executor.getKeepAliveTime(TimeUnit.SECONDS)); + + executor.setKeepAliveTime(120L, TimeUnit.SECONDS); + assertEquals(120L, executor.getKeepAliveTime(TimeUnit.SECONDS)); + } +} diff --git a/src/test/java/de/rub/nds/crawler/util/CancellableFutureTest.java b/src/test/java/de/rub/nds/crawler/util/CancellableFutureTest.java new file mode 100644 index 0000000..e5e7216 --- /dev/null +++ b/src/test/java/de/rub/nds/crawler/util/CancellableFutureTest.java @@ -0,0 +1,285 @@ +/* + * TLS-Crawler - A TLS scanning tool to perform large scale scans with the TLS-Scanner + * + * Copyright 2018-2023 Ruhr University Bochum, Paderborn University, and Hackmanit GmbH + * + * Licensed under Apache License, Version 2.0 + * http://www.apache.org/licenses/LICENSE-2.0.txt + */ +package de.rub.nds.crawler.util; + +import static org.junit.jupiter.api.Assertions.*; + +import java.util.concurrent.*; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import org.junit.jupiter.api.*; + +class CancellableFutureTest { + + private ExecutorService executor; + + @BeforeEach + void setUp() { + executor = Executors.newSingleThreadExecutor(); + } + + @AfterEach + void tearDown() { + executor.shutdown(); + } + + @Test + void testCallableConstructorAndRun() throws Exception { + // Given + String expectedResult = "test result"; + Callable callable = () -> expectedResult; + CancellableFuture future = new CancellableFuture<>(callable); + + // When + future.run(); + + // Then + assertTrue(future.isDone()); + assertEquals(expectedResult, future.get()); + } + + @Test + void testRunnableConstructorAndRun() throws Exception { + // Given + AtomicBoolean wasRun = new AtomicBoolean(false); + Runnable runnable = () -> wasRun.set(true); + String expectedResult = "fixed result"; + CancellableFuture future = new CancellableFuture<>(runnable, expectedResult); + + // When + future.run(); + + // Then + assertTrue(wasRun.get()); + assertTrue(future.isDone()); + assertEquals(expectedResult, future.get()); + } + + @Test + void testGetWithTimeout() throws Exception { + // Given + String expectedResult = "delayed result"; + Callable callable = + () -> { + Thread.sleep(100); + return expectedResult; + }; + CancellableFuture future = new CancellableFuture<>(callable); + + // When + executor.submit(future); + + // Then + assertEquals(expectedResult, future.get(1, TimeUnit.SECONDS)); + } + + @Test + void testGetWithTimeoutThrowsTimeoutException() { + // Given + Callable callable = + () -> { + Thread.sleep(2000); + return "too late"; + }; + CancellableFuture future = new CancellableFuture<>(callable); + + // When + executor.submit(future); + + // Then + assertThrows(TimeoutException.class, () -> future.get(100, TimeUnit.MILLISECONDS)); + } + + @Test + void testCancel() throws Exception { + // Given + CountDownLatch startLatch = new CountDownLatch(1); + CountDownLatch endLatch = new CountDownLatch(1); + AtomicBoolean wasInterrupted = new AtomicBoolean(false); + + Callable callable = + () -> { + startLatch.countDown(); + try { + Thread.sleep(5000); + return "should not complete"; + } catch (InterruptedException e) { + wasInterrupted.set(true); + endLatch.countDown(); + throw e; + } + }; + CancellableFuture future = new CancellableFuture<>(callable); + + // When + executor.submit(future); + startLatch.await(); // Wait for task to start + boolean cancelled = future.cancel(true); + + // Then + assertTrue(cancelled); + assertTrue(future.isCancelled()); + assertTrue(future.isDone()); + + // Wait for interrupt to be processed + assertTrue(endLatch.await(1, TimeUnit.SECONDS)); + assertTrue(wasInterrupted.get()); + } + + @Test + void testCancelWithoutInterrupt() { + // Given + Callable callable = + () -> { + Thread.sleep(5000); + return "should not complete"; + }; + CancellableFuture future = new CancellableFuture<>(callable); + + // When - cancel before running + boolean cancelled = future.cancel(false); + + // Then + assertTrue(cancelled); + assertTrue(future.isCancelled()); + assertTrue(future.isDone()); + } + + @Test + void testGetAfterCancel() throws Exception { + // Given + AtomicInteger callCount = new AtomicInteger(0); + String partialResult = "partial"; + + Callable callable = + () -> { + callCount.incrementAndGet(); + // Simulate some work that produces partial result + Thread.sleep(100); + return partialResult; + }; + + CancellableFuture future = new CancellableFuture<>(callable); + + // When + executor.submit(future); + Thread.sleep(50); // Let it start + future.cancel(true); + + // Then - get() should return the partial result after cancellation + try { + future.get(); + } catch (CancellationException e) { + // Expected when result is not yet available + } + } + + @Test + void testGetWithTimeoutAfterCancel() throws Exception { + // Given + String partialResult = "partial"; + CountDownLatch resultSetLatch = new CountDownLatch(1); + + Runnable runnable = + () -> { + try { + Thread.sleep(100); + resultSetLatch.countDown(); + } catch (InterruptedException e) { + resultSetLatch.countDown(); + } + }; + + CancellableFuture future = new CancellableFuture<>(runnable, partialResult); + + // When + executor.submit(future); + Thread.sleep(50); // Let it start + future.cancel(true); + resultSetLatch.await(); // Wait for result to be set + + // Then + try { + String result = future.get(1, TimeUnit.SECONDS); + assertEquals(partialResult, result); + } catch (CancellationException e) { + // This may happen if the timing is different + } + } + + @Test + void testIsDoneBeforeExecution() { + // Given + Callable callable = () -> "result"; + CancellableFuture future = new CancellableFuture<>(callable); + + // Then + assertFalse(future.isDone()); + } + + @Test + void testIsCancelledBeforeCancellation() { + // Given + Callable callable = () -> "result"; + CancellableFuture future = new CancellableFuture<>(callable); + + // Then + assertFalse(future.isCancelled()); + } + + @Test + void testExceptionPropagation() { + // Given + RuntimeException expectedException = new RuntimeException("Test exception"); + Callable callable = + () -> { + throw expectedException; + }; + CancellableFuture future = new CancellableFuture<>(callable); + + // When + future.run(); + + // Then + ExecutionException thrown = assertThrows(ExecutionException.class, future::get); + assertEquals(expectedException, thrown.getCause()); + } + + @Test + void testMultipleGetCalls() throws Exception { + // Given + String expectedResult = "result"; + Callable callable = () -> expectedResult; + CancellableFuture future = new CancellableFuture<>(callable); + + // When + future.run(); + + // Then - multiple gets should return the same result + assertEquals(expectedResult, future.get()); + assertEquals(expectedResult, future.get()); + assertEquals(expectedResult, future.get(1, TimeUnit.SECONDS)); + } + + @Test + void testCancelAfterCompletion() throws Exception { + // Given + Callable callable = () -> "result"; + CancellableFuture future = new CancellableFuture<>(callable); + + // When + future.run(); + boolean cancelled = future.cancel(true); + + // Then + assertFalse(cancelled); + assertFalse(future.isCancelled()); + assertTrue(future.isDone()); + } +}