# BERT int8 inference problems for parameter b = 1 incuding all the
# relevant post-ops and data types propagation
# 2D problems have M = b * 384
# 4D problems have batch = b x 16
#              
# In total, there are 24 identical fragments in the topology:
#         ____|____ -----.
#        /    |    \     :
#      MM_1  MM_2  MM_3  :
#       |     |   /      :
#       |    MM_4 -------:
#        \   /           :
#         MM_5           :
#           |            :
#         MM_6 ----------`
#           |
#     Layer_norm ---.
#           |       :
#         MM_7      :
#           |       :
#         MM_8 -----`
#           |
#     Layer_norm

--reset
--skip-impl=ref
--dt=u8:s8:s8 --attr-scales=src:common:0.25+wei:common:0.5+dst:common:0.025 --bia_dt=bf16 --bia_mask=2
--attr-zero-points=dst:common:7+src:common:3
--stag=ab --wtag=any --dtag=ab
# MM_3 is the same
384x1024:1024x1024n"BERT:MM_1*48"

--reset
--skip-impl=ref
--dt=u8:s8:u8 --attr-scales=src:common:0.25+wei:common:0.5+dst:common:0.025 --bia_dt=bf16 --bia_mask=2
--attr-zero-points=dst:common:7+src:common:3
--stag=ab --wtag=any --dtag=ab
384x1024:1024x1024n"BERT:MM_2*24"

--reset
--skip-impl=ref
--dt=u8:s8:bf16 --stag=abcd --wtag=abdc --dtag=abcd
--attr-scales=src:common:0.25+wei:common:0.5+dst:common:0.025
#--attr-post-ops=add:bf16:13
--attr-zero-points=dst:common:7+wei:common:2+src:common:3
1x16x384x64:1x16x64x384n"BERT:MM_4*24"

--reset
--skip-impl=ref
--dt=u8:s8:u8 --stag=abcd --wtag=abcd --dtag=abcd
--attr-scales=src:common:0.25+wei:common:0.5+dst:common:0.025
--attr-zero-points=dst:common:7+wei:common:2+src:common:3
1x16x384x384:1x16x384x64n"BERT:MM_5*24"

--reset
--skip-impl=ref
--dt=u8:s8:bf16
--attr-scales=src:common:0.25+wei:common:0.5+dst:common:0.025 --bia_dt=bf16 --bia_mask=2
#--attr-post-ops=add:bf16:per_tensor
--attr-zero-points=dst:common:7+src:common:3
--stag=ab --wtag=any --dtag=ab
384x1024:1024x1024n"BERT:MM_6*24"

--reset
--skip-impl=ref
--dt=u8:s8:u8
--attr-scales=src:common:0.25+wei:common:0.5+dst:common:0.025 --bia_dt=bf16 --bia_mask=2
--attr-zero-points=dst:common:7+src:common:3
--stag=ab --wtag=any --dtag=ab
384x1024:1024x4096n"BERT:MM_7*24"

--reset
--skip-impl=ref
--dt=u8:s8:bf16
--attr-scales=src:common:0.25+wei:common:0.5+dst:common:0.025 --bia_dt=bf16 --bia_mask=2
--attr-zero-points=dst:common:7+src:common:3
#--attr-post-ops=add:bf16:per_tensor
--stag=ab --wtag=any --dtag=ab
384x4096:4096x1024n"BERT:MM_8*24"
