#!/usr/bin/env python2.7
# Copyright 2015 gRPC authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Starts a local DNS server for use in tests"""

import argparse
import sys
import yaml
import signal
import os
import threading
import time

import twisted
import twisted.internet
import twisted.internet.reactor
import twisted.internet.threads
import twisted.internet.defer
import twisted.internet.error
import twisted.internet.protocol
import twisted.names
import twisted.names.client
import twisted.names.dns
import twisted.names.server
from twisted.names import client, server, common, authority, dns
import argparse
import platform

# Deliberately has no trailing '.': names stored in the in-memory authority
# are written without it (see start_local_dns_server).
_SERVER_HEALTH_CHECK_RECORD_NAME = 'health-check-local-dns-server-is-alive.resolver-tests.grpctestingexp'  # missing end '.' for twisted syntax
_SERVER_HEALTH_CHECK_RECORD_DATA = '123.123.123.123'

# A DNS <character-string> (one string inside a TXT record) is limited to
# 255 octets (RFC 1035, section 3.3).
_TXT_CHUNK_MAX_LEN = 255


class NoFileAuthority(authority.FileAuthority):
    """A FileAuthority whose SOA and records are supplied in memory.

    FileAuthority.__init__ expects a zone file to load, so it is skipped and
    only common.ResolverBase is initialized.
    """

    def __init__(self, soa, records):
        # skip FileAuthority
        common.ResolverBase.__init__(self)
        # soa: (zone name, dns.Record_SOA) tuple, as built in
        # start_local_dns_server.
        self.soa = soa
        # records: dict mapping record name -> list of dns.Record_* objects.
        self.records = records


def _split_txt_data(txt_data, max_len=_TXT_CHUNK_MAX_LEN):
    """Split txt_data into consecutive chunks of at most max_len characters.

    Empty input yields an empty list.
    """
    return [
        txt_data[start:start + max_len]
        for start in range(0, len(txt_data), max_len)
    ]


def start_local_dns_server(args):
    """Load the configured test records and serve them over TCP and UDP.

    Args:
      args: parsed command-line namespace; reads records_config_path, port
        and add_a_record.

    Runs the twisted reactor and only returns once the reactor stops.

    Raises:
      ValueError: if a record name joined with the common zone name does not
        end with '.'.
    """
    all_records = {}

    def _push_record(name, r):
        print('pushing record: |%s|' % name)
        all_records.setdefault(name, []).append(r)

    def _maybe_split_up_txt_data(name, txt_data, r_ttl):
        # One TXT record may hold several strings; split long data so each
        # string respects the per-string length limit.
        _push_record(name,
                     dns.Record_TXT(*_split_txt_data(txt_data), ttl=r_ttl))

    with open(args.records_config_path) as config:
        # The config is plain data. yaml.load() without an explicit Loader is
        # deprecated since PyYAML 5.1 and rejected by PyYAML 6.
        test_records_config = yaml.safe_load(config)
    common_zone_name = test_records_config['resolver_tests_common_zone_name']
    for group in test_records_config['resolver_component_tests']:
        for name, records in group['records'].items():
            for record in records:
                r_type = record['type']
                r_data = record['data']
                r_ttl = int(record['TTL'])
                record_full_name = '%s.%s' % (name, common_zone_name)
                if not record_full_name.endswith('.'):
                    raise ValueError(
                        'record name %r must end with "." (check '
                        'resolver_tests_common_zone_name)' % record_full_name)
                # Stored names drop the trailing '.'.
                record_full_name = record_full_name[:-1]
                if r_type == 'A':
                    _push_record(record_full_name,
                                 dns.Record_A(r_data, ttl=r_ttl))
                elif r_type == 'AAAA':
                    _push_record(record_full_name,
                                 dns.Record_AAAA(r_data, ttl=r_ttl))
                elif r_type == 'SRV':
                    p, w, port, target = r_data.split(' ')
                    # NOTE(review): unlike record names, the SRV target keeps
                    # the zone's trailing '.'; confirm this is intended.
                    target_full_name = '%s.%s' % (target, common_zone_name)
                    _push_record(
                        record_full_name,
                        dns.Record_SRV(int(p),
                                       int(w),
                                       int(port),
                                       target_full_name,
                                       ttl=r_ttl))
                elif r_type == 'TXT':
                    _maybe_split_up_txt_data(record_full_name, r_data, r_ttl)
                # Any other record type is silently ignored.
    # Add an optional IPv4 record if one is specified.
    if args.add_a_record:
        extra_host, extra_host_ipv4 = args.add_a_record.split(':')
        _push_record(extra_host, dns.Record_A(extra_host_ipv4, ttl=0))
    # Server health check record
    _push_record(_SERVER_HEALTH_CHECK_RECORD_NAME,
                 dns.Record_A(_SERVER_HEALTH_CHECK_RECORD_DATA, ttl=0))
    soa_record = dns.Record_SOA(mname=common_zone_name)
    test_domain_com = NoFileAuthority(
        soa=(common_zone_name, soa_record),
        records=all_records,
    )
    # Named so it does not shadow the imported twisted.names `server` module.
    dns_server_factory = twisted.names.server.DNSServerFactory(
        authorities=[test_domain_com], verbose=2)
    dns_server_factory.noisy = 2
    twisted.internet.reactor.listenTCP(args.port, dns_server_factory)
    dns_proto = twisted.names.dns.DNSDatagramProtocol(dns_server_factory)
    dns_proto.noisy = 2
    twisted.internet.reactor.listenUDP(args.port, dns_proto)
    print('starting local dns server on 127.0.0.1:%s' % args.port)
    print('starting twisted.internet.reactor')
    twisted.internet.reactor.suggestThreadPoolSize(1)
    twisted.internet.reactor.run()


def _quit_on_signal(signum, _frame):
    """Signal handler: stop the reactor if it is running, then exit with 0."""
    print('Received SIGNAL %d. Quitting with exit code 0' % signum)
    try:
        twisted.internet.reactor.stop()
    except twisted.internet.error.ReactorNotRunning:
        # main() installs this handler before the reactor starts, so a signal
        # can arrive while records are still loading, or after a previous
        # signal already stopped the reactor. Still exit cleanly with 0.
        pass
    sys.stdout.flush()
    sys.exit(0)


def flush_stdout_loop():
    """Flush stdout once per second, then SIGTERM this process after a timeout.

    The process is terminated after roughly max_timeouts * sleep_time seconds.
    """
    num_timeouts_so_far = 0
    sleep_time = 1
    # Prevent zombies. Tests that use this server are short-lived.
    max_timeouts = 60 * 10
    while num_timeouts_so_far < max_timeouts:
        sys.stdout.flush()
        time.sleep(sleep_time)
        num_timeouts_so_far += 1
    print('Process timeout reached, or cancelled. Exitting 0.')
    os.kill(os.getpid(), signal.SIGTERM)


def main():
    """Parse command-line flags, install signal handlers and run the server."""
    argp = argparse.ArgumentParser(
        description='Local DNS Server for resolver tests')
    argp.add_argument('-p',
                      '--port',
                      default=None,
                      type=int,
                      help='Port for DNS server to listen on for TCP and UDP.')
    argp.add_argument(
        '-r',
        '--records_config_path',
        default=None,
        type=str,
        help=('Directory of resolver_test_record_groups.yaml file. '
              'Defaults to path needed when the test is invoked as part '
              'of run_tests.py.'))
    argp.add_argument(
        '--add_a_record',
        default=None,
        type=str,
        help=('Add an A record via the command line. Useful for when we '
              'need to serve a one-off A record that is under a '
              'different domain then the rest the records configured in '
              '--records_config_path (which all need to be under the '
              'same domain). Format: <name>:<ipv4 address>'))
    args = argp.parse_args()
    signal.signal(signal.SIGTERM, _quit_on_signal)
    signal.signal(signal.SIGINT, _quit_on_signal)
    output_flush_thread = threading.Thread(target=flush_stdout_loop)
    # Daemon thread so it never keeps the process alive on its own. The
    # attribute form works on Python 2.7 and 3 (setDaemon() is deprecated).
    output_flush_thread.daemon = True
    output_flush_thread.start()
    start_local_dns_server(args)


if __name__ == '__main__':
    main()