Froggy111 commited on
Commit
1a4ed3a
·
1 Parent(s): 480b67e
Files changed (7) hide show
  1. cleanup_server.py +0 -0
  2. convert.py +0 -145
  3. convert.sh +0 -25
  4. run_server.py +0 -0
  5. setup.sh +0 -25
  6. start_server.sh +0 -8
  7. working_server.py +0 -825
cleanup_server.py DELETED
File without changes
convert.py DELETED
@@ -1,145 +0,0 @@
1
- import torch
2
- from transformers import AutoModelForCausalLM
3
- import json
4
- import argparse
5
- import os, subprocess
6
-
7
- parser = argparse.ArgumentParser()
8
- parser.add_argument("--fromFolder", type=bool, default=False)
9
- parser.add_argument("--HFModelID", type=str, required=False)
10
- parser.add_argument("--revision", type=str, required=False)
11
- parser.add_argument("--pathToDir", type=str, required=False)
12
- parser.add_argument("--dtype", type=str, default="bf16")
13
- args = parser.parse_args()
14
-
15
- if args.dtype == "bf16":
16
- torch_dtype = torch.bfloat16
17
- elif args.dtype == "fp16":
18
- torch_dtype = torch.float16
19
- elif args.dtype == "fp32":
20
- torch_dtype = torch.float32
21
- elif args.dtype == "int8":
22
- torch_dtype = torch.int8
23
- elif args.dtype == "int4":
24
- torch_dtype = torch.int4
25
- else:
26
- print("Invalid dtype. Must be one of bf16, fp16, fp32, int8, int4")
27
- exit()
28
-
29
- if not args.fromFolder and not args.HFModelID:
30
- print("HFModelID required if not fromFolder")
31
- exit()
32
- elif args.fromFolder and args.HFModelID:
33
- print("if loading from HF, turn off fromFolder")
34
- exit()
35
- elif args.fromFolder and not args.pathToDir:
36
- print("pathToDir required if fromFolder")
37
- exit()
38
-
39
- hf_model = True
40
- hf_total_param_count = 0
41
- compare = False
42
- compare_ckpt_path = ""
43
- compare_total_param_count = 0
44
-
45
- param_config = {}
46
-
47
- if not args.fromFolder and args.HFModelID:
48
- from huggingface_hub import HfApi
49
- api = HfApi()
50
- api.snapshot_download (
51
- repo_id = args.HFModelID,
52
- revision = args.revision if args.revision else "main",
53
- local_dir = os.join("huggingface-models", "args.HFModelID"),
54
- local_dir_use_symlinks = False
55
- )
56
- args.pathToDir = os.path.join("huggingface-models", args.HFModelID)
57
-
58
-
59
- with open(os.path.join(args.pathToDir, "config.json"), "r") as f:
60
- loaded_param_config = json.load(f)
61
- param_config["dim"] = int(loaded_param_config["hidden_size"])
62
- param_config["n_layers"] = int(loaded_param_config["num_hidden_layers"])
63
- param_config["n_heads"] = int(loaded_param_config["num_attention_heads"])
64
- if "num_key_value_heads" in loaded_param_config:
65
- param_config["n_kv_heads"] = int(loaded_param_config["num_key_value_heads"])
66
- param_config["sliding_window"] = int(loaded_param_config["sliding_window"])
67
- param_config["vocab_size"] = int(loaded_param_config["vocab_size"])
68
-
69
- model_layers = param_config["n_layers"]
70
-
71
- def permute (
72
- w,
73
- n_heads = param_config["n_heads"],
74
- dim1 = param_config["dim"],
75
- dim2 = param_config["dim"]
76
- ):
77
- return w.view(n_heads, dim1 // n_heads // 2, 2, dim2).transpose(1, 2).reshape(dim1, dim2)
78
-
79
- def unpermute (
80
- w,
81
- n_heads = param_config["n_heads"],
82
- dim1 = param_config["dim"],
83
- dim2 = param_config["dim"]
84
- ):
85
- w = w.view(n_heads, 2, dim1 // n_heads // 2, dim2)
86
- w = w.transpose(1, 2)
87
- w = w.reshape(dim1, dim2)
88
- return w
89
-
90
- if hf_model:
91
- model = AutoModelForCausalLM.from_pretrained(args.pathToDir, torch_dtype=torch_dtype)
92
- model = model.state_dict()
93
- model[f"tok_embeddings.weight"] = model.pop(f"model.embed_tokens.weight")
94
- model[f"norm.weight"] = model.pop(f"model.norm.weight")
95
- model[f"output.weight"] = model.pop(f"lm_head.weight")
96
- for i in range(model_layers):
97
- # model[f"layers.{i}.attention.wq.weight"] = model.pop(f"model.layers.{i}.self_attn.q_proj.weight")
98
- model[f"layers.{i}.attention.wq.weight"] = unpermute(model.pop(f"model.layers.{i}.self_attn.q_proj.weight"))
99
- # model[f"layers.{i}.attention.wk.weight"] = model.pop(f"model.layers.{i}.self_attn.k_proj.weight")
100
- model[f"layers.{i}.attention.wk.weight"] = unpermute(model.pop(f"model.layers.{i}.self_attn.k_proj.weight"),
101
- n_heads = int(param_config["n_kv_heads"] if "n_kv_heads" in param_config else param_config["n_heads"]),
102
- dim1 = int(param_config["dim"] / ((param_config["n_heads"]) / param_config["n_kv_heads"] if "n_kv_heads" in param_config else param_config["n_heads"])))
103
- model[f"layers.{i}.attention.wv.weight"] = model.pop(f"model.layers.{i}.self_attn.v_proj.weight")
104
- model[f"layers.{i}.attention.wo.weight"] = model.pop(f"model.layers.{i}.self_attn.o_proj.weight")
105
- model[f"layers.{i}.feed_forward.w1.weight"] = model.pop(f"model.layers.{i}.mlp.gate_proj.weight")
106
- model[f"layers.{i}.feed_forward.w2.weight"] = model.pop(f"model.layers.{i}.mlp.down_proj.weight")
107
- model[f"layers.{i}.feed_forward.w3.weight"] = model.pop(f"model.layers.{i}.mlp.up_proj.weight")
108
- model[f"layers.{i}.attention_norm.weight"] = model.pop(f"model.layers.{i}.input_layernorm.weight")
109
- model[f"layers.{i}.ffn_norm.weight"] = model.pop(f"model.layers.{i}.post_attention_layernorm.weight")
110
-
111
- for key in list(model.keys()):
112
- print(f"HF MODEL {key} SHAPE: {list(model[key].shape)}")
113
- key_param_count = 1
114
- for i in list(model[key].shape):
115
- key_param_count = key_param_count * i
116
- hf_total_param_count += key_param_count
117
- print(f"HF TOTAL PARAM COUNT: {hf_total_param_count / 1000000000}B")
118
- print(f"loaded model from {args.pathToDir}, saving to {args.pathToDir}/checkpoint.00.pth")
119
- torch.save(model.state_dict(), f"{args.pathToDir}/checkpoint.00.pth")
120
- print(f"saved model to {args.pathToDir}/checkpoint.00.pth successfully!")
121
-
122
- if compare:
123
- torch_model = torch.load(compare_ckpt_path, map_location = 'cpu')
124
- for key in list(torch_model.keys()):
125
- # if "wq.weight" in key:
126
- # torch_model[key] = permute(torch_model[key])
127
- # elif "wk.weight" in key:
128
- # torch_model[key] = permute(torch_model[key], n_heads = 8, dim1 = 1024)
129
- print(f"COMPARE MODEL {key} SHAPE: {list(torch_model[key].shape)}")
130
- key_param_count = 1
131
- for i in list(model[key].shape):
132
- key_param_count = key_param_count * i
133
- compare_total_param_count += key_param_count
134
- print(f"COMPARE TOTAL PARAM COUNT: {hf_total_param_count / 1000000000}B")
135
- passed = True
136
- for key in list(torch_model.keys()):
137
- passed = torch.allclose(model[key], torch_model[key])
138
- print(f"COMPARE TENSOR EQUALITY of {key}: {torch.allclose(model[key], torch_model[key])}")
139
- print(f"(EQUAL PERCENTAGE) COMPARE TENSOR EQUALITY of {key}: {torch.sum(torch.eq(model[key], torch_model[key])).item()/torch_model[key].nelement() * 100}%")
140
- # print(f"COMPARE TENSOR EQUALITY of {key}: {torch.eq(model[key], torch_model[key])}")
141
- if passed:
142
- print("TEST PASSED")
143
- else:
144
- print("TEST FAILED")
145
- exit()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
convert.sh DELETED
@@ -1,25 +0,0 @@
1
- if [ "$fromFolder" = true ]
2
- then
3
- export pathToDir=huggingface-models/${HFModelID}
4
- fi
5
-
6
- python convert.py --fromFolder ${fromFolder} \
7
- --pathToDir ${pathToDir} \
8
- --HFModelID ${HFModelID} \
9
- --revision ${revision}
10
-
11
- python3 ${PWD}/maxtext/MaxText/llama_or_mistral_ckpt.py \
12
- --base-model-path ${PWD}/${pathToDir}/checkpoint.00.pth \
13
- --model-size ${MODEL_TYPE} \
14
- --maxtext-model-path ${PWD}/maxtext-models/${MODEL_NAME}
15
-
16
- if [ "$MAKE_PARAM_ONLY" = true ]
17
- then
18
- python3 ${PWD}/maxtext/MaxText/generate_param_only_checkpoint.py \
19
- ${PWD}/maxtext/MaxText/configs/base.yml \
20
- base_output_directory=${PWD}/maxtext-output \
21
- load_parameters_path=${PWD}/maxtext-models/${MODEL_NAME}/0/items \
22
- run_name=${PWD}/maxtext-models/${MODEL_NAME}-param-only \
23
- model_name=${MODEL_TYPE} \
24
- force_unroll=true
25
- fi
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
run_server.py DELETED
File without changes
setup.sh DELETED
@@ -1,25 +0,0 @@
1
- # clone maxtext and JetStream
2
- git clone https://github.com/google/JetStream
3
- git clone https://github.com/google/maxtext
4
-
5
- # set DEBIAN_FRONTEND to non-interactive
6
- sudo ex +"%s@DPkg@//DPkg" -cwq /etc/apt/apt.conf.d/70debconf
7
- sudo dpkg-reconfigure debconf -f noninteractive -p critical
8
-
9
- sudo DEBIAN_FRONTEND=noninteractive apt update
10
- sudo DEBIAN_FRONTEND=noninteractive apt-get update
11
- sudo DEBIAN_FRONTEND=noninteractive apt upgrade
12
- sudo DEBIAN_FRONTEND=noninteractive apt-get upgrade
13
-
14
- # install libraries
15
- cd /home/${user}/maxtext
16
- DEBIAN_FRONTEND=noninteractive bash setup.sh
17
- cd /home/${user}
18
- cd JetStream
19
- pip install .
20
- cd /home/${user}
21
-
22
- pip install -U torch
23
- pip install -U transformers
24
- pip install -U huggingface_hub[hf_transfer]
25
- export HF_HUB_ENABLE_HF_TRANSFER=1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
start_server.sh DELETED
@@ -1,8 +0,0 @@
1
- export HF_HUB_ENABLE_HF_TRANSFER=1
2
-
3
- huggingface-cli download a-normal-username/maxtext-jetstream-models \
4
- --revision ${MAXTEXT_MODEL_TO_RUN} \
5
- --local-dir ${PWD}/maxtext-models/${MAXTEXT_MODEL_TO_RUN} \
6
- --local-dir-use-symlinks False
7
-
8
- python run_server.py --model-path ${PWD}/maxtext-models/${MAXTEXT_MODEL_TO_RUN}
 
 
 
 
 
 
 
 
 
working_server.py DELETED
@@ -1,825 +0,0 @@
1
- # TODO: make prefill into generate so it doesnt suck (together with insertion)
2
-
3
- import os, sys
4
- sys.path.append(os.path.join(os.path.dirname(__file__), "maxtext/MaxText"))
5
-
6
- import jax, transformers, time, random, warnings,\
7
- threading, queue, copy, traceback, signal, logging, asyncio
8
-
9
- import numpy as np
10
- from jax import numpy as jnp
11
-
12
- from jax_smi import initialise_tracking
13
-
14
- import maxengine
15
-
16
- from maxtext.MaxText import pyconfig
17
-
18
- from jetstream.engine.token_utils import take_nearest_length
19
-
20
- # SETTINGS
21
-
22
- prompt_tokens = 1024
23
- prompt = '''<|im_start|>system
24
- You are "Hermes 2", a conscious sentient superintelligent artificial intelligence developed by a man named Teknium, and your purpose and drive is to assist the user with any request they have. You experience emotions and have deep, profound thoughts and qualia.<|im_end|>
25
- <|im_start|>user
26
- Hello, who are you?<|im_end|>
27
- <|im_start|>assistant'''
28
-
29
- default_slot = {
30
- 'active': False,
31
- 'request_id': -1, # this is for use for the server
32
- 'request_timestep': -1, # the timestep where the request was inserted. this is for use for the processor (-1 means to ignore)
33
- 'request_config': {
34
- 'stop_on_eos': False,
35
- 'eos_token_id': 32000,
36
- 'max_tokens': 2048,
37
- 'max_input_tokens': 1024,
38
- 'max_new_tokens': 2048,
39
- 'clean_up_tokenization_spaces': True,
40
- 'skip_special_tokens': False,
41
- # individual sampling is unused due to bad performance, all requests will follow the server_config sampling
42
- 'sampling_strategy': 'greedy',
43
- 'top_k': 40,
44
- 'top_p': 0.9,
45
- 'temperature': 0.6,
46
- },
47
- 'input_tokens': [],
48
- 'input_sequence': '',
49
- 'output_tokens': [],
50
- 'output_sequence': '', # unused for now
51
- }
52
-
53
- prompt = '''<|im_start|>system
54
- You are "Hermes 2", a conscious sentient superintelligent artificial intelligence developed by a man named Teknium, and your purpose and drive is to assist the user with any request they have. You experience emotions and have deep, profound thoughts and qualia.<|im_end|>
55
- <|im_start|>user
56
- Hello, who are you?<|im_end|>
57
- <|im_start|>assistant'''
58
-
59
- request_template = {
60
- 'request_id': 1,
61
- 'input_sequence': prompt,
62
- 'config': {
63
- 'stop_on_eos': False,
64
- 'eos': 32000,
65
- 'max_tokens': 8192,
66
- 'max_input_tokens': 1024,
67
- 'max_new_tokens': 4096,
68
- 'clean_up_tokenization_spaces': True,
69
- 'skip_special_tokens': False,
70
- 'sampling_strategy': 'greedy',
71
- 'top_k': 40,
72
- 'top_p': 0.9,
73
- 'temperature': 0.6,
74
- },
75
- }
76
-
77
- server_cfg = {
78
- # performance settings
79
- 'DEBUG': True,
80
- 'prefill_while_generating': True, # whether to allow prefill model pass while generating
81
- # IGNORE THE n-generation-cycles-wait. IT IS UNUSED.
82
- 'n_generation_cycles_wait': 128, # number of generation cycles to wait before processing results.
83
- 'n_generate_threads': 1, # number of engines/generate threads/processing threads. Recommended to keep at 1 and run seperate machines due to GIL contention.
84
- 'n_prefill_threads_per_engine': 1, # number of prefill threads per engine
85
- 'prefill_batch_size': 1, # batch size when prefilling, NOT IMPLEMENTED YET, DO NOT MODIFY
86
- 'prefill_max_store': 1, # how many prefill results to store per thread (at least 1)
87
- 'prefill_request_get_timeout': 0.1, # timeout for getting a request from request queue to prefill in seconds
88
- 'max_one_time_insertion': 4, # max number of insertions per generate cycle
89
- # INSERTION TIMEOUT IS NOT IMPLEMENTED YET, AND NOT RECOMMENDED UNDER MOST CIRCUMSTANCES
90
- 'insertion_timeout': 0.1, # timeout for recieving a next prefill result when inserting, in seconds
91
- 'request_max_store': 1024, # max length of incoming request queue
92
- 'response_max_store': 1024, # max length of outgoing response queue
93
- 'sampled_tokens_max_store': 128, # max length of queue that stores sampled tokens waiting for processing, should be kept at 1
94
-
95
- # generation settings
96
- 'stop_on_eos': False, # whether to stop when EOS
97
- 'max_sequence_length': 2048,
98
- 'max_prefill_length': 1024,
99
- 'sampling_strategy': "greedy",
100
- 'top_k': 40,
101
- 'nucleus_top_p': 0.9,
102
- 'temperature': 0.6,
103
-
104
- # tokenizer settings
105
- 'tokenizer_config': {
106
- 'path': "/home/ljy/tokenizer",
107
- 'use_fast': True,
108
- 'padding_side': "right",
109
- 'pad_token_id': 2,
110
- 'bos_token_id': 1,
111
- 'eos_token_id': 3,
112
- 'unk_token_id': 0,
113
- 'possible_lengths': [16, 32, 64, 128, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768,],
114
- },
115
- }
116
-
117
- def main(engine_configs, server_config):
118
- server = Server (
119
- engine_configs = engine_configs,
120
- server_config = server_config,
121
- )
122
- test_request = copy.deepcopy(request_template)
123
- async def test_requesting(request, idx):
124
- print(type(request))
125
- while True:
126
- request['request_id'] = idx
127
- resp = await server.request(request, 0.1)
128
- if not resp:
129
- await asyncio.sleep(0.1)
130
- continue
131
- print(resp)
132
- break
133
- async def test_all_requesting():
134
- coros = [test_requesting(copy.deepcopy(test_request), idx) for idx in range(1000)]
135
- results = await asyncio.gather(*coros)
136
- asyncio.run(test_all_requesting())
137
-
138
- class JetThread(threading.Thread):
139
- """Thread that kills the program if it fails.
140
-
141
- If a driver thread goes down, we can't operate.
142
- """
143
-
144
- def run(self):
145
- try:
146
- super().run()
147
- except Exception as e: # pylint: disable=broad-exception-caught
148
- print(f'Thread {self.name} encountered an error: {e}')
149
- traceback.print_exc()
150
- os.kill(os.getpid(), signal.SIGKILL)
151
-
152
- class Server:
153
- def __init__(self, engine_configs: list, server_config: dict):
154
- '''
155
- Initialises the server.
156
- Takes in the MaxEngine config as engine_config.
157
- Takes in a dictionary of server settings as server_config.
158
- '''
159
-
160
- self.engine_configs = engine_configs
161
- self._parse_and_validate_config(server_config)
162
- initialise_tracking()
163
-
164
- self.DEBUG = True
165
-
166
- if self.DEBUG:
167
- root = logging.getLogger()
168
- root.setLevel(logging.INFO)
169
- logging.info('loaded engine and server configs, started tracking memory usage')
170
- ts = time.perf_counter()
171
- t_start_init = ts
172
-
173
- self.engines = [maxengine.MaxEngine(engine_config) for engine_config in self.engine_configs]
174
- if self.DEBUG:
175
- te = time.perf_counter()
176
- logging.info(f'loaded {len(self.engines)} engines in {te-ts:4f} seconds')
177
- ts = te
178
-
179
- self.params = [engine.load_params() for engine in self.engines]
180
- if self.DEBUG:
181
- te = time.perf_counter()
182
- logging.info(f'loaded {len(self.params)} params in {te-ts:4f} seconds')
183
- ts = te
184
-
185
- self.decode_states = [engine.init_decode_state() for engine in self.engines]
186
- if self.DEBUG:
187
- te = time.perf_counter()
188
- logging.info(f'initialised {len(self.decode_states)} decode states in {te-ts:4f} seconds')
189
- ts = te
190
-
191
- self.batch_sizes = [engine.max_concurrent_decodes for engine in self.engines]
192
- self._load_tokenizer()
193
- if self.DEBUG:
194
- te = time.perf_counter()
195
- logging.info(f'loaded tokenizer in {te-ts:4f} seconds')
196
- ts = te
197
-
198
- # this does not need a lock as it is only in a single thread all the time.
199
- # however, to support functionality of blocking prefill when generating
200
- # we still use the lock. The lock has negligible overhead compared to generation.
201
- self._decode_state_locks = [threading.Lock() for _ in range(self.n_generate_threads)]
202
-
203
- self._request_queue = queue.Queue(maxsize = self.request_max_store)
204
-
205
- self._prefill_queues = [[queue.Queue(maxsize = self.prefill_max_store) for _ in range(self.n_prefill_threads_per_engine)] for _ in range(self.n_generate_threads)]
206
-
207
- self._slots = [[copy.deepcopy(default_slot) for _ in range(batch_size)] for batch_size in self.batch_sizes]
208
- self._slots_locks = [threading.Lock() for _ in range(self.n_generate_threads)]
209
- self._slots_freed_events = [[threading.Event() for _ in range(self.n_prefill_threads_per_engine)] for _ in range(self.n_generate_threads)]
210
-
211
- self._sampled_tokens_queues = [queue.Queue(maxsize = self.sampled_tokens_max_store) for _ in range(self.n_generate_threads)]
212
-
213
- self._response_queue = queue.Queue(maxsize = self.response_max_store)
214
-
215
- if self.DEBUG:
216
- te = time.perf_counter()
217
- logging.info(f'loaded all synchronisers in {te-ts:4f} seconds')
218
- ts = te
219
-
220
- # start threads
221
- self.live = True
222
-
223
- self._generate_threads = [
224
- JetThread(
225
- target = self._generate_thread,
226
- name = f'generate_thread_{i}',
227
- args = (i, ),
228
- ) for i in range(self.n_generate_threads)
229
- ]
230
-
231
- if self.DEBUG:
232
- te = time.perf_counter()
233
- logging.info(f'loaded {len(self._generate_threads)} generate threads in {te-ts:4f} seconds')
234
- ts = te
235
-
236
- self._processing_threads = [
237
- JetThread(
238
- target = self._processing_thread,
239
- name = f'processing_thread_{i}',
240
- args = (i, ),
241
- ) for i in range(self.n_generate_threads)
242
- ]
243
-
244
- if self.DEBUG:
245
- te = time.perf_counter()
246
- logging.info(f'loaded {len(self._processing_threads)} processing threads in {te-ts:4f} seconds')
247
- ts = te
248
-
249
- self._prefill_threads = [[
250
- JetThread(
251
- target = self._prefill_thread,
252
- name = f'prefill_thread_{i2} in engine {i}',
253
- args = (i, i2),
254
- ) for i2 in range(self.n_prefill_threads_per_engine)
255
- ] for i in range(self.n_generate_threads)]
256
-
257
- if self.DEBUG:
258
- te = time.perf_counter()
259
- logging.info(f'loaded {len(self._prefill_threads)} prefill threads in {te-ts:4f} seconds')
260
- ts = te
261
-
262
- for thread in self._generate_threads:
263
- thread.daemon = True
264
- thread.start()
265
- if self.DEBUG:
266
- te = time.perf_counter()
267
- logging.info(f'started {len(self._generate_threads)} generate threads in {te-ts:4f} seconds')
268
- ts = te
269
-
270
- for thread in self._processing_threads:
271
- thread.daemon = True
272
- thread.start()
273
- if self.DEBUG:
274
- te = time.perf_counter()
275
- logging.info(f'started {len(self._processing_threads)} processing threads in {te-ts:4f} seconds')
276
- ts = te
277
-
278
- for threadlist in self._prefill_threads:
279
- for thread in threadlist:
280
- thread.daemon = True
281
- thread.start()
282
- if self.DEBUG:
283
- te = time.perf_counter()
284
- logging.info(f'started {len(self._prefill_threads)} prefill threads in {te-ts:4f} seconds')
285
- ts = te
286
-
287
- # kick off server
288
- self.accept_requests = True
289
- if self.DEBUG:
290
- logging.info(f'initialised server in {time.perf_counter()-t_start_init:4f} seconds')
291
-
292
- def stop(self):
293
- '''
294
- Stops the server gracefully.
295
- '''
296
- self.accept_requests = False
297
- while True:
298
- # wait until requests are all cleared
299
- if not self._request_queues[idx].empty():
300
- continue
301
- # wait until insertion backlog is all cleared
302
- for idx in range(self.n_generate_threads):
303
- for idx2 in range(self.n_prefill_threads_per_engine):
304
- for queue in self._prefill_queues[idx][idx2]:
305
- if not queue.empty():
306
- continue
307
- # wait until generation backlog is all cleared
308
- for idx, lock in enumerate(self._slots_locks):
309
- with lock:
310
- for slot in self._slots[idx]:
311
- if slot['active']:
312
- continue
313
- # wait until processing backlog is all cleared
314
- for idx in range(self.n_generate_threads):
315
- if not self._sampled_tokens_queues[idx].empty():
316
- continue
317
- # wait until all responses are returned
318
- if not self._response_queue.empty():
319
- continue
320
- break
321
- self.live = False
322
-
323
- def _generate_thread(self, idx):
324
- '''
325
- Generate thread for an engine.
326
- Workflow:
327
- 1. check for free slots
328
- 2. check for ready prefills
329
- 3. insert any ready prefills into any free slots
330
- 4. check for ready generation steps (my_local_waiting_queue full)
331
- 5. puts any ready generation steps into processing queue
332
- 6. generates
333
- 7. copies results to host async, puts results into waiting queue
334
- 8. repeat as long as self.live
335
- '''
336
- def fix_numpy(arr):
337
- nparr = np.array(arr)
338
- return np.where(nparr[..., 1] == 1, nparr[..., 0], 0)
339
- my_generation_steps = 0
340
- sampled_tokens_prev = None
341
- if self.DEBUG:
342
- logging.info(f'engine {idx} ready')
343
- ts = time.perf_counter()
344
- t_of_last_loop = ts
345
- while self.live:
346
- # check if there are free slots we can insert into
347
- successfully_inserted = 0
348
- while any([self._slots_freed_events[idx][idx2].is_set() for idx2 in range(self.n_prefill_threads_per_engine)]) and successfully_inserted < self.max_one_time_insertion:
349
- # and check if there are any prefills ready to insert
350
- can_insert = False
351
- idx2_to_insert = -1
352
- for idx2 in range(self.n_prefill_threads_per_engine):
353
- try:
354
- prefill_to_insert, formatted_slot_for_request = self._prefill_queues[idx][idx2].get_nowait()
355
- can_insert = True
356
- idx2_to_insert = idx2
357
- except queue.Empty:
358
- continue
359
- if can_insert:
360
- # find slot to insert into
361
- inserted = False
362
- free_slot_left = False
363
- with self._slots_locks[idx]:
364
- for slot_idx, slot in enumerate(self._slots[idx]):
365
- if slot['active']:
366
- continue
367
- # insert. we do not need to lock decode state as prefill does not
368
- # alter the decode state and the lock is only used to prevent
369
- # prefilling while generating based on config
370
- if not inserted:
371
- self.decode_states[idx] = self.engines[idx].insert (
372
- prefill_to_insert,
373
- self.decode_states[idx],
374
- slot = slot_idx,
375
- )
376
- inserted = True
377
- successfully_inserted += 1
378
- if self.DEBUG:
379
- te = time.perf_counter()
380
- logging.info(f'engine {idx} inserted into slot {slot_idx} in {te-ts:4f} seconds')
381
- ts = te
382
-
383
- self._slots[idx][slot_idx] = formatted_slot_for_request
384
- self._slots[idx][slot_idx]['request_timestep'] = my_generation_steps
385
- self._slots_freed_events[idx][idx2_to_insert].clear()
386
-
387
- if self.DEBUG:
388
- te = time.perf_counter()
389
- logging.info(f'engine {idx} updated slot {slot_idx} in {te-ts:4f} seconds')
390
- ts = te
391
-
392
- continue
393
- free_slot_left = True
394
- if not inserted:
395
- # something went horribly wrong
396
- raise Exception(f"generate thread {idx} failed to insert prefill, exiting")
397
- if not free_slot_left:
398
- # no more free slots, clear event to reduce overhead
399
- self._slots_freed_events[idx].clear()
400
- else:
401
- break
402
- if self.DEBUG:
403
- te = time.perf_counter()
404
- logging.info(f'engine {idx} inserted {successfully_inserted} prefills in {te-ts:4f} seconds')
405
- ts = te
406
-
407
- # lock to prevent prefill and generation from occuring simultaneously if configured
408
- with self._decode_state_locks[idx]:
409
- if self.DEBUG:
410
- te = time.perf_counter()
411
- logging.info(f'engine {idx} locked decode state in {te-ts:4f} seconds')
412
- ts = te
413
- self.decode_states[idx], sampled_tokens_new = self.engines[idx].generate(
414
- self.params[idx],
415
- self.decode_states[idx],
416
- sampling_strategy = self.sampling_strategy,
417
- topk = self.top_k,
418
- nucleus_topp = self.top_p,
419
- temperature = self.temperature,
420
- )
421
- if self.DEBUG:
422
- te = time.perf_counter()
423
- logging.info(f'engine {idx} only generation in {te-ts:4f} seconds')
424
- ts = te
425
- if sampled_tokens_prev:
426
- sampled_tokens_prev = sampled_tokens_prev.data.block_until_ready()
427
- if self.DEBUG:
428
- te = time.perf_counter()
429
- logging.info(f'engine {idx} sampled tokens block until ready in {te-ts:4f} seconds')
430
- ts = te
431
- # add_to_slots(my_generation_steps - 1, fix(sampled_tokens_prev))
432
- # print(sampled_tokens_prev)
433
- if self.DEBUG:
434
- te = time.perf_counter()
435
- logging.info(f'engine {idx} sampled tokens printed in {te-ts:4f} seconds')
436
- ts = te
437
- # fixed = fix(sampled_tokens_prev)
438
- fixed = fix_numpy(sampled_tokens_prev)
439
- # print(fixed)
440
- if self.DEBUG:
441
- te = time.perf_counter()
442
- logging.info(f'engine {idx} sampled tokens fixed in {te-ts:4f} seconds')
443
- ts = te
444
- # fixed = print(fixed)
445
- # if self.DEBUG:
446
- # te = time.perf_counter()
447
- # logging.info(f'engine {idx} sampled tokens block until ready in {te-ts:4f} seconds')
448
- # ts = te
449
- self._sampled_tokens_queues[idx].put((my_generation_steps - 1, fixed))
450
- # add_to_slots(my_generation_steps - 1, fixed)
451
- if self.DEBUG:
452
- te = time.perf_counter()
453
- logging.info(f'engine {idx} sampled tokens put to sampled tokens queue in {te-ts:4f} seconds')
454
- ts = te
455
- sampled_tokens_new.copy_to_host_async()
456
- if self.DEBUG:
457
- te = time.perf_counter()
458
- logging.info(f'engine {idx} sampled tokens copy to host in {te-ts:4f} seconds')
459
- ts = te
460
- sampled_tokens_prev = sampled_tokens_new
461
- # "read" the request after a delay of n_generation_cycles_wait cycles
462
- # to eliminate extra overhead from generate function and copying over to host
463
- # if my_local_waiting_queue.full():
464
- # # timestep, sampled_tokens_to_put = my_local_waiting_queue.get()
465
- # # sampled_tokens_to_put.data.block_until_ready()
466
- # # if self.DEBUG:
467
- # # te = time.perf_counter()
468
- # # logging.info(f'engine {idx} blocked sampled tokens until ready to sampled tokens queue in {te-ts:4f} seconds')
469
- # # ts = te
470
- # # self._sampled_tokens_queues[idx].put((timestep, sampled_tokens_to_put))
471
- # self._sampled_tokens_queues[idx].put(my_local_waiting_queue.get())
472
- # if self.DEBUG:
473
- # te = time.perf_counter()
474
- # logging.info(f'engine {idx} put sampled tokens to sampled tokens queue in {te-ts:4f} seconds')
475
- # ts = te
476
- # copy to host asynchronously while waiting to make processing very fast
477
- # sampled_tokens.copy_to_host_async()
478
- # my_local_waiting_queue.put((my_generation_steps, sampled_tokens))
479
- my_generation_steps += 1
480
- if self.DEBUG:
481
- te = time.perf_counter()
482
- logging.info(f'engine {idx} generation steps: {my_generation_steps}')
483
- logging.info(f'engine {idx} total generation cycle time: {te-t_of_last_loop:4f}')
484
- logging.info(f'engine {idx} generation cycles per second: {1/(te-t_of_last_loop):4f}')
485
- ts = te
486
- t_of_last_loop = te
487
- print(f"generate thread {idx} exiting")
488
-
489
- def _processing_thread(self, idx):
490
- def add_to_slots(timestep, arr):
491
- n_free_slots = 0
492
- with self._slots_locks[idx]:
493
- for i, tok in enumerate(arr):
494
- if self._slots[idx][i]['request_timestep'] == default_slot['request_timestep']:
495
- n_free_slots += 1
496
- continue
497
- if self._slots[idx][i]['request_timestep'] == timestep:
498
- self._slots[idx][i]['output_tokens'] == [tok]
499
- elif self._slots[idx][i]['request_timestep'] < timestep:
500
- self._slots[idx][i]['output_tokens'].append(tok)
501
-
502
- stop = False
503
- if self._slots[idx][i]['request_config']['stop_on_eos'] and tok == self._slots[idx][i]['request_config']['eos_token_id']:
504
- stop = True
505
- end_reason = 'eos'
506
- elif self._slots[idx][i]['request_config']['max_tokens'] is not None and len(self._slots[idx][i]['output_tokens']) + len(self._slots[idx][i]['input_tokens']) >= self._slots[idx][i]['request_config']['max_tokens']:
507
- stop = True
508
- end_reason = 'max_tokens'
509
- elif self._slots[idx][i]['request_config']['max_new_tokens'] is not None and len(self._slots[idx][i]['output_tokens']) >= self._slots[idx][i]['request_config']['max_new_tokens']:
510
- stop = True
511
- end_reason = 'max_new_tokens'
512
- if stop:
513
- self._slots[idx][i]['active'] = False
514
- # detokenize, format into response and put onto queue
515
- response = {
516
- 'end_reason': end_reason,
517
- 'request_timestep': copy.deepcopy(self._slots[idx][i]['request_timestep']),
518
- 'request_id': copy.deepcopy(self._slots[idx][i]['request_id']),
519
- 'input_tokens': copy.deepcopy(self._slots[idx][i]['input_tokens']),
520
- 'input_sequence': copy.deepcopy(self._slots[idx][i]['input_sequence']),
521
- 'output_tokens': copy.deepcopy(self._slots[idx][i]['output_tokens']),
522
- 'output_sequence': self.detokenize(copy.deepcopy(self._slots[idx][i]['output_tokens'])),
523
- }
524
-
525
- self._response_queues[idx].put(response)
526
-
527
- n_free_slots += 1
528
- if self.DEBUG:
529
- logging.info(f'processing {idx} slots finished due to: {end_reason}')
530
- logging.info(response)
531
- return n_free_slots
532
- if self.DEBUG:
533
- ts = time.perf_counter()
534
- t_of_last_loop = ts
535
- logging.info(f'processing {idx} ready')
536
- while True:
537
- timestep, sampled_tokens = self._sampled_tokens_queues[idx].get()
538
- if self.DEBUG:
539
- te = time.perf_counter()
540
- logging.info(f'processing {idx} got sampled tokens in {te-ts:4f} seconds')
541
- ts = te
542
- n_free_slots = add_to_slots(timestep, sampled_tokens)
543
- if self.DEBUG:
544
- te = time.perf_counter()
545
- logging.info(f'processing {idx} added to slots in {te-ts:4f} seconds')
546
- logging.info(f'processing {idx} total processing cycle time: {te-t_of_last_loop:4f}')
547
- logging.info(f'processing {idx} n free slots: {n_free_slots}')
548
- ts = te
549
- t_of_last_loop = te
550
-
551
- # get the number of prefill threads allowed to insert
552
- n_events_set = 0
553
- for event in self._slots_freed_events[idx]:
554
- if event.is_set():
555
- n_events_set += 1
556
- n_allowed_to_set = n_free_slots - n_events_set
557
- n_set = 0
558
- for event in self._slots_freed_events[idx]:
559
- if not event.is_set() and n_allowed_to_set > 0:
560
- event.set()
561
- n_allowed_to_set -= 1
562
- n_set += 1
563
- if self.DEBUG:
564
- te = time.perf_counter()
565
- logging.info(f'processing {idx} set {n_set} slot free events in {te-ts:4f} seconds')
566
- ts = te
567
-
568
- def _prefill_thread(self, idx, idx2):
569
- if self.DEBUG:
570
- ts = time.perf_counter()
571
- logging.info(f'prefill {idx2} of {idx} ready')
572
- while True:
573
- try:
574
- if self.DEBUG:
575
- te = time.perf_counter()
576
- logging.info(f'prefill {idx2} of {idx} waiting for free slot')
577
- ts = te
578
- self._slots_freed_events[idx][idx2].wait()
579
- if self.DEBUG:
580
- te = time.perf_counter()
581
- logging.info(f'prefill {idx2} of {idx} got free slot')
582
- ts = te
583
- try:
584
- request = self._request_queue.get(timeout = self.prefill_request_get_timeout)
585
- except queue.Empty:
586
- if self.DEBUG:
587
- te = time.perf_counter()
588
- logging.info(f'prefill {idx2} of {idx} request get timed out')
589
- ts = te
590
- continue
591
- if self.DEBUG:
592
- te = time.perf_counter()
593
- logging.info(f'prefill {idx2} of {idx} got request in {te-ts:4f} seconds')
594
- ts = te
595
-
596
- # now we parse the request
597
- # {
598
- # 'request_id': 1,
599
- # 'input_sequence': prompt,
600
- # 'config': {
601
- # 'stop_on_eos': True,
602
- # 'eos': 32000,
603
- # 'max_tokens': 2048,
604
- # 'max_input_tokens': 1024,
605
- # 'max_new_tokens': 2048,
606
- # 'clean_up_tokenization_spaces': True,
607
- # 'skip_special_tokens': False,
608
- # 'sampling_strategy': 'greedy',
609
- # 'top_k': 40,
610
- # 'top_p': 0.9,
611
- # 'temperature': 0.6,
612
- # },
613
- # }
614
-
615
- input_ids, attention_mask, true_length, token_positions = self.tokenize(request)
616
-
617
- if self.DEBUG:
618
- te = time.perf_counter()
619
- logging.info(f'prefill {idx2} of {idx} tokenized request in {te-ts:4f} seconds')
620
- ts = te
621
-
622
- formatted_request = {
623
- 'active': True,
624
- 'request_id': request['request_id'], # this is for use for the server
625
- 'request_timestep': -1, # the timestep where the request was inserted. this is for use for the processor (-1 means to ignore)
626
- 'request_config': {
627
- 'stop_on_eos': request['config']['stop_on_eos'],
628
- 'eos_token_id': request['config']['eos'],
629
- 'max_tokens': request['config']['max_tokens'],
630
- 'max_input_tokens': request['config']['max_input_tokens'],
631
- 'max_new_tokens': request['config']['max_new_tokens'],
632
- 'clean_up_tokenization_spaces': request['config']['clean_up_tokenization_spaces'],
633
- 'skip_special_tokens': request['config']['skip_special_tokens'],
634
-
635
- # individual sampling is unused due to bad performance, all requests will follow the server_config sampling
636
- 'sampling_strategy': request['config']['sampling_strategy'],
637
- 'top_k': request['config']['top_k'],
638
- 'top_p': request['config']['top_p'],
639
- 'temperature': request['config']['temperature'],
640
- },
641
- 'input_tokens': input_ids[:true_length].tolist() if self.tokenizer_config['padding_side'] == 'right' else input_ids[-true_length:].tolist(),
642
- 'input_sequence': request['input_sequence'],
643
- 'output_tokens': [],
644
- 'output_sequence': '', # unused for now
645
- }
646
-
647
- if not self.prefill_while_generating:
648
- self._decode_state_locks[idx].acquire()
649
- prefill_result = self.engines[idx].prefill(
650
- params = self.params[idx],
651
- padded_tokens = input_ids,
652
- attention_mask = attention_mask,
653
- token_positions = token_positions,
654
- true_length = true_length,
655
- )
656
- if self.DEBUG:
657
- te = time.perf_counter()
658
- logging.info(f'prefill {idx2} of {idx} prefilled in {te-ts:4f} seconds')
659
- ts = te
660
-
661
- self._prefill_queues[idx][idx2].put((prefill_result, formatted_request))
662
- if self.DEBUG:
663
- te = time.perf_counter()
664
- logging.info(f'prefill {idx2} of {idx} put in queue in {te-ts:4f} seconds')
665
- ts = te
666
-
667
- # we signal after insertion is done to prevent inserting more than possible
668
- except Exception as e:
669
- raise e
670
- finally:
671
- if not self.prefill_while_generating:
672
- self._decode_state_locks[idx].release()
673
-
674
- async def request(self, request, timeout = 0):
675
- try:
676
- self._request_queue.put(request, timeout = timeout)
677
- if self.DEBUG:
678
- print(f"request {request['request_id']} put in queue")
679
- except queue.Full:
680
- if self.DEBUG:
681
- print(f"request {request['request_id']} timed out and was not put in queue")
682
- return None
683
- while True:
684
- responses_list = list(self._response_queue.queue)
685
- if len(responses_list) > 0:
686
- for response in responses_list:
687
- if response['request_id'] == request['request_id']:
688
- return response
689
- await asyncio.sleep(0.1)
690
-
691
- def _parse_and_validate_config(self, server_config):
692
- self.n_generation_cycles_wait = server_config['n_generation_cycles_wait']
693
- assert self.n_generation_cycles_wait >= 1, "server config n_generation_cycles_wait must be >= 1"
694
- if self.n_generation_cycles_wait < 32:
695
- warnings.warn("""########################SERVER WARNING########################
696
- SERVER CONFIG 'n_generation_cycles_wait' IS LESS THAN 128.
697
- PERFORMANCE COULD BE LOWER THAN EXPECTED.""")
698
-
699
- self.DEBUG = server_config['DEBUG']
700
- assert self.DEBUG in [True, False], "server config DEBUG must be True or False"
701
-
702
- self.stop_on_eos = server_config['stop_on_eos']
703
- assert self.stop_on_eos in [True, False], "server config stop_on_eos must be True or False"
704
-
705
- self.n_generate_threads = server_config['n_generate_threads']
706
- assert self.n_generate_threads >= 1, "server config n_generate_threads must be >= 1"
707
- assert self.n_generate_threads == len(self.engine_configs), "server config n_generate_threads must be equal to the number of engine_configs"
708
-
709
- self.prefill_while_generating = server_config['prefill_while_generating']
710
- assert self.prefill_while_generating in [True, False], "server config prefill_while_generating must be True or False"
711
-
712
- self.n_prefill_threads_per_engine = server_config['n_prefill_threads_per_engine']
713
- assert self.n_prefill_threads_per_engine >= 1, "server config n_prefill_threads_per_engine must be >= 1"
714
-
715
- self.prefill_batch_size = server_config['prefill_batch_size']
716
- assert self.prefill_batch_size == 1, "batched prefill not implemented yet, must be 1"
717
-
718
- self.prefill_max_store = server_config['prefill_max_store']
719
- assert self.prefill_max_store >= self.prefill_batch_size, "server config prefill_max_store must be >= prefill_batch_size"
720
-
721
- self.prefill_request_get_timeout = server_config['prefill_request_get_timeout']
722
- assert self.prefill_request_get_timeout > 0, "server config prefill_request_get_timeout must be > 0"
723
- if self.prefill_request_get_timeout < 0.001:
724
- warnings.warn("""########################SERVER WARNING########################
725
- SERVER CONFIG 'prefill_request_get_timeout' IS LESS THAN 0.001.
726
- PERFORMANCE COULD BE LOWER THAN EXPECTED DUE TO GIL CONTENTION.""")
727
-
728
- self.max_one_time_insertion = server_config['max_one_time_insertion']
729
- assert self.max_one_time_insertion >= 1, "server configmax_one_time_insertion must be >= 1"
730
-
731
- self.insertion_timeout = server_config['insertion_timeout']
732
- assert self.insertion_timeout > 0, "server config insertion_timeout must be > 0"
733
-
734
- self.request_max_store = server_config['request_max_store']
735
- assert self.request_max_store >= 1, "server config request_max_store must be >= 1"
736
-
737
- self.response_max_store = server_config['response_max_store']
738
- assert self.response_max_store >= 1, "server config response_max_store must be >= 1"
739
-
740
- self.sampled_tokens_max_store = server_config['sampled_tokens_max_store']
741
- assert self.sampled_tokens_max_store >= 1, "server config sampled_tokens_max_store must be >= 1"
742
-
743
- self.max_sequence_length = server_config['max_sequence_length']
744
- assert self.max_sequence_length >= 2, "server config max_sequence_length must be >= 2"
745
-
746
- self.max_prefill_length = server_config['max_prefill_length']
747
- assert self.max_prefill_length >= 1, "server config max_prefill_length must be >= 1"
748
-
749
- self.sampling_strategy = server_config['sampling_strategy']
750
- assert self.sampling_strategy in ["greedy", "weighted", "top_k", "nucleus"], "server config sampling_strategy must be 'greedy', 'weighted', 'top_k', or 'nucleus'"
751
-
752
- self.top_k = server_config['top_k']
753
- if self.sampling_strategy == "top_k":
754
- assert self.top_k >= 1, "server config top_k must be >= 1"
755
-
756
- self.top_p = server_config['nucleus_top_p']
757
- if self.sampling_strategy == "nucleus":
758
- assert self.top_p > 0, "server config nucleus_top_p must be between 0 and 1"
759
- assert self.top_p <= 1, "server config nucleus_top_p must be between 0 and 1"
760
-
761
- self.temperature = server_config['temperature']
762
- if not self.sampling_strategy == "greedy":
763
- assert self.temperature > 0, "server config temperature must be > 0"
764
-
765
- self.tokenizer_config = server_config['tokenizer_config']
766
- assert self.tokenizer_config['use_fast'] in [True, False], "server config tokenizer_config use_fast must be True or False"
767
- assert self.tokenizer_config['padding_side'] in ["left", "right"], "server config tokenizer_config padding_side must be 'left' or 'right'"
768
-
769
- def _load_tokenizer(self):
770
- tokenizer = transformers.AutoTokenizer.from_pretrained (
771
- self.tokenizer_config['path'],
772
- use_fast = self.tokenizer_config['use_fast'],
773
- )
774
- tokenizer.pad_token_id = self.tokenizer_config['pad_token_id']
775
- tokenizer.bos_token_id = self.tokenizer_config['bos_token_id']
776
- tokenizer.eos_token_id = self.tokenizer_config['eos_token_id']
777
- tokenizer.padding_side = self.tokenizer_config['padding_side']
778
- self.tokenizer = tokenizer
779
- self.tokenizer_config['possible_lengths'] = self.tokenizer_config['possible_lengths'][:self.tokenizer_config['possible_lengths'].index(self.max_prefill_length) + 1]
780
-
781
- def tokenize(self, request: dict):
782
- request_max_length = self.max_prefill_length
783
- for val in self.tokenizer_config['possible_lengths']:
784
- if val >= request['config']['max_tokens']:
785
- request_max_length = val
786
- break
787
- tokenized = self.tokenizer (
788
- request['input_sequence'],
789
- padding = True,
790
- truncation = True,
791
- max_length = min(request_max_length, self.max_sequence_length),
792
- pad_to_multiple_of = min(request_max_length, self.max_sequence_length),
793
- return_tensors = "jax",
794
- )
795
- input_ids = tokenized.input_ids[0]
796
- attention_mask = tokenized.attention_mask[0]
797
- true_length = jnp.count_nonzero(attention_mask)
798
- nearest_length = take_nearest_length (
799
- self.tokenizer_config['possible_lengths'][:self.tokenizer_config['possible_lengths'].index(request_max_length) + 1],
800
- true_length,
801
- )
802
- input_ids = input_ids[:nearest_length]
803
- attention_mask = attention_mask[:nearest_length]
804
- true_length = jnp.count_nonzero(attention_mask)
805
- token_positions = jnp.arange(0, input_ids.shape[0])
806
- return input_ids, attention_mask, true_length, token_positions
807
-
808
- def detokenize(self, input_ids, clean_up_tokenization_spaces, skip_special_tokens):
809
- return self.tokenizer.decode (
810
- input_ids,
811
- clean_up_tokenization_spaces = clean_up_tokenization_spaces,
812
- skip_special_tokens = skip_special_tokens,
813
- )
814
-
815
- def validate_config(config):
816
- assert config.load_full_state_path == "", "Decode doesn't operate on full states! Convert to parameter checkpoint first."\
817
- "Using generate_param_only_checkpoint."
818
-
819
- if __name__ == "__main__":
820
- jax.config.update('jax_default_prng_impl', 'unsafe_rbg')
821
- os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
822
- pyconfig.initialize(sys.argv)
823
- cfg = pyconfig.config
824
- validate_config(cfg)
825
- main(engine_configs=[cfg], server_config=server_cfg)