oneAPI Deep Neural Network Library (oneDNN)
Performance library for Deep Learning
2.1.2
dnnl.hpp
Go to the documentation of this file.
1 /*******************************************************************************
2 * Copyright 2016-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
19 
20 #ifndef ONEAPI_DNNL_DNNL_HPP
21 #define ONEAPI_DNNL_DNNL_HPP
22 
23 #include "oneapi/dnnl/dnnl_config.h"
24 
26 #include <algorithm>
27 #include <cstdlib>
28 #include <iterator>
29 #include <memory>
30 #include <string>
31 #include <vector>
32 #include <unordered_map>
33 
34 #include "oneapi/dnnl/dnnl.h"
35 
37 
// Decide whether the C++ API is allowed to use C++ exceptions. A user may
// force either choice by defining DNNL_ENABLE_EXCEPTIONS before including
// this header. Otherwise exceptions count as enabled when:
//  - __cpp_exceptions is set (see
//    https://gcc.gnu.org/onlinedocs/libstdc++/manual/using_exceptions.html),
//  - or __EXCEPTIONS is set (gcc < 5 defines only this one),
//  - or the compiler is the Microsoft C++ compiler (and not clang), which
//    provides no option to disable exceptions.
#if !defined(DNNL_ENABLE_EXCEPTIONS)
#if !(__cpp_exceptions || __EXCEPTIONS \
        || (defined(_MSC_VER) && !defined(__clang__)))
#define DNNL_ENABLE_EXCEPTIONS 0
#else
#define DNNL_ENABLE_EXCEPTIONS 1
#endif
#endif
50 
// DNNL_TRAP() terminates the program abnormally. It is used by
// DNNL_THROW_ERROR when C++ exceptions are disabled.
#if defined(__GNUC__) || defined(__clang__)
#define DNNL_TRAP() __builtin_trap()
#elif defined(__INTEL_COMPILER) || defined(_MSC_VER)
#define DNNL_TRAP() __debugbreak()
#else
// Any other compiler: use the portable standard-library abort (declared in
// <cstdlib>) instead of refusing to compile the header.
#define DNNL_TRAP() std::abort()
#endif
58 
// DNNL_THROW_ERROR(status, msg) reports a fatal API error.
//  - Exceptions disabled: writes `msg` to stderr and stops the program via
//    DNNL_TRAP(); `status` is not used in this configuration.
//  - Exceptions enabled: throws `error(status, msg)`.
#if !DNNL_ENABLE_EXCEPTIONS
#include <cstdio>
#define DNNL_THROW_ERROR(status, msg) \
    do { \
        fputs(msg, stderr); \
        DNNL_TRAP(); \
    } while (0)
#else
#define DNNL_THROW_ERROR(status, msg) throw error(status, msg)
#endif
69 
72 
74 namespace dnnl {
75 
79 
84 struct error : public std::exception {
86  const char *message;
87 
92  error(dnnl_status_t status, const char *message)
93  : status(status), message(message) {}
94 
96  const char *what() const noexcept override { return message; }
97 
103  static void wrap_c_api(dnnl_status_t status, const char *message) {
104  if (status != dnnl_success) DNNL_THROW_ERROR(status, message);
105  }
106 };
107 
109 template <typename T>
110 void validate_container_size(const T &v, const char *error_message,
111  int min_size = 1, int max_size = -1) {
112  const int size = (int)v.size();
113  if (size < min_size || (max_size >= 0 && size > max_size))
114  DNNL_THROW_ERROR(dnnl_invalid_arguments, error_message);
115 }
117 
119 template <typename T>
120 struct handle_traits {};
121 
135 template <typename T, typename traits = handle_traits<T>>
136 struct handle {
137 private:
138  static dnnl_status_t dummy_destructor(T) { return dnnl_success; }
139  std::shared_ptr<typename std::remove_pointer<T>::type> data_ {0};
140 
141 protected:
142  bool operator==(const T other) const { return other == data_.get(); }
143  bool operator!=(const T other) const { return !(*this == other); }
144 
145 public:
153  handle() = default;
154 
156  handle(const handle<T, traits> &) = default;
160  handle(handle<T, traits> &&) = default;
163 
169  explicit handle(T t, bool weak = false) { reset(t, weak); }
170 
176  void reset(T t, bool weak = false) {
177  data_.reset(t, weak ? &dummy_destructor : traits::destructor);
178  }
179 
185  T get(bool allow_empty = false) const {
186  T result = data_.get();
187  if (allow_empty == false && result == nullptr)
188  DNNL_THROW_ERROR(
189  dnnl_invalid_arguments, "object is not initialized");
190  return result;
191  }
192 
197  explicit operator T() const { return get(true); }
198 
202  explicit operator bool() const { return get(true) != nullptr; }
203 
210  bool operator==(const handle<T, traits> &other) const {
211  return other.data_.get() == data_.get();
212  }
213 
220  bool operator!=(const handle &other) const { return !(*this == other); }
221 };
222 
224 template <>
225 struct handle_traits<dnnl_memory_t> {
226  static dnnl_status_t destructor(dnnl_memory_t p) {
227  return dnnl_memory_destroy(p);
228  }
229 };
230 
231 template <>
232 struct handle_traits<dnnl_primitive_desc_t> {
233  static dnnl_status_t destructor(dnnl_primitive_desc_t p) {
234  return dnnl_primitive_desc_destroy(p);
235  }
236 };
237 
238 template <>
239 struct handle_traits<dnnl_primitive_t> {
240  static dnnl_status_t destructor(dnnl_primitive_t p) {
241  return dnnl_primitive_destroy(p);
242  }
243 };
244 
245 template <>
246 struct handle_traits<dnnl_primitive_desc_iterator_t> {
247  static dnnl_status_t destructor(dnnl_primitive_desc_iterator_t p) {
249  }
250 };
252 
254 
// Forward declarations: these types are referenced by primitive before they
// are defined.
255 struct stream;
256 struct memory;
257 struct primitive_desc;
258 
263 
267 
/// Base class for computational primitives: a reference-counted handle to a
/// C API dnnl_primitive_t.
269 struct primitive : public handle<dnnl_primitive_t> {
/// Kinds of primitives; each value mirrors the corresponding C API constant.
// NOTE(review): the embedded numbering jumps (271 -> 281 -> 293 -> 301 -> 315),
// so enumerators may be missing from this listing. Confirm against the full
// header.
271  enum class kind {
281  sum = dnnl_sum,
293  lrn = dnnl_lrn,
301  rnn = dnnl_rnn,
315  prelu = dnnl_prelu,
316  };
317 
// Inherit handle's constructors (wrapping an existing dnnl_primitive_t).
318  using handle::handle;
319 
/// Constructs an empty primitive.
321  primitive() = default;
322 
// NOTE(review): numbering jumps 322 -> 342; further member declarations may
// be missing here.
327 
332 
338 
/// Returns the kind of this primitive (defined out of line below).
342  inline kind get_kind() const;
343 
/// Executes the primitive on stream `astream`. Arguments are passed as a map
/// from an int key (presumably a DNNL_ARG_* index; confirm) to a memory
/// object.
356  void execute(const stream &astream,
357  const std::unordered_map<int, memory> &args) const;
358 };
359 
// NOTE(review): the fragments below are the tails of three out-of-line
// definitions whose signature lines (and some body lines) are missing from
// this listing; see the numbering gaps 359 -> 365, 367 -> 371 and 374 -> 377.
// Fragment: converts `akind` (presumably a primitive::kind) to the C enum.
365  return static_cast<dnnl_primitive_kind_t>(akind);
366 }
367 
// Fragment: reports "could not get a primitive descriptor from a primitive"
// on failure and returns the descriptor `pd`.
371  "could not get a primitive descriptor from a primitive");
372  return pd;
373 }
374 
// Fragment: queries dnnl_query_primitive_kind from `pd` and converts the result
// to dnnl::primitive::kind (presumably the body of primitive::get_kind()).
377  // TODO (Roma): the code below is only needed because get_primitive_desc
378  // returns a C type.
381  pd, dnnl_query_primitive_kind, 0, (void *)&kind),
382  "could not get a primitive kind from a primitive descriptor");
383  return static_cast<dnnl::primitive::kind>(kind);
384 }
385 
387 
399 
/// Scratchpad mode; converted to dnnl_scratchpad_mode_t by the function below.
// NOTE(review): numbering jumps 401 -> 424, so the enumerators are not shown in
// this listing.
401 enum class scratchpad_mode {
424 };
425 
// Fragment (signature line missing): casts `mode`, presumably a
// scratchpad_mode, to the C enum dnnl_scratchpad_mode_t.
431  return static_cast<dnnl_scratchpad_mode_t>(mode);
432 }
433 
/// Propagation kind; converted to dnnl_prop_kind_t by the function below.
// NOTE(review): numbering jumps 435 -> 459; enumerators are not shown here.
435 enum class prop_kind {
459 };
460 
// Fragment (signature line missing): casts `akind`, presumably a prop_kind,
// to dnnl_prop_kind_t.
466  return static_cast<dnnl_prop_kind_t>(akind);
467 }
468 
/// Algorithm kind; mirrors the C API dnnl_alg_kind_t values.
// NOTE(review): numbering jumps 472 -> 600; only `undef` is shown here.
470 enum class algorithm {
472  undef = dnnl_alg_kind_undef,
600 };
601 
// Fragment (signature line missing): casts `aalgorithm` to dnnl_alg_kind_t.
606  return static_cast<dnnl_alg_kind_t>(aalgorithm);
607 }
608 
610 
613 
/// Flags for normalization primitives. The underlying type is unsigned so
/// the values can be combined with the bitwise operators defined by
/// DNNL_DEFINE_BITMASK_OPS.
// NOTE(review): the enumerators are not shown in this listing.
615 enum class normalization_flags : unsigned {
621 
630 
637 
643 };
644 
// Fragment (signature line missing): casts `flags` to
// dnnl_normalization_flags_t.
649  return static_cast<dnnl_normalization_flags_t>(flags);
650 }
651 
653 
656 
/// Flags for RNN primitives (unsigned underlying type, combinable via
/// DNNL_DEFINE_BITMASK_OPS).
// NOTE(review): the enumerators are not shown in this listing.
658 enum class rnn_flags : unsigned {
661 };
662 
// Fragment (signature line missing): casts `flags` to dnnl_rnn_flags_t.
667  return static_cast<dnnl_rnn_flags_t>(flags);
668 }
669 
670 #define DNNL_DEFINE_BITMASK_OPS(enum_name) \
671  inline enum_name operator|(enum_name lhs, enum_name rhs) { \
672  return static_cast<enum_name>( \
673  static_cast<unsigned>(lhs) | static_cast<unsigned>(rhs)); \
674  } \
675 \
676  inline enum_name operator&(enum_name lhs, enum_name rhs) { \
677  return static_cast<enum_name>( \
678  static_cast<unsigned>(lhs) & static_cast<unsigned>(rhs)); \
679  } \
680 \
681  inline enum_name operator^(enum_name lhs, enum_name rhs) { \
682  return static_cast<enum_name>( \
683  static_cast<unsigned>(lhs) ^ static_cast<unsigned>(rhs)); \
684  } \
685 \
686  inline enum_name &operator|=(enum_name &lhs, enum_name rhs) { \
687  lhs = static_cast<enum_name>( \
688  static_cast<unsigned>(lhs) | static_cast<unsigned>(rhs)); \
689  return lhs; \
690  } \
691 \
692  inline enum_name &operator&=(enum_name &lhs, enum_name rhs) { \
693  lhs = static_cast<enum_name>( \
694  static_cast<unsigned>(lhs) & static_cast<unsigned>(rhs)); \
695  return lhs; \
696  } \
697 \
698  inline enum_name &operator^=(enum_name &lhs, enum_name rhs) { \
699  lhs = static_cast<enum_name>( \
700  static_cast<unsigned>(lhs) ^ static_cast<unsigned>(rhs)); \
701  return lhs; \
702  } \
703 \
704  inline enum_name operator~(enum_name rhs) { \
705  return static_cast<enum_name>(~static_cast<unsigned>(rhs)); \
706  }
707 
708 DNNL_DEFINE_BITMASK_OPS(normalization_flags)
709 DNNL_DEFINE_BITMASK_OPS(rnn_flags)
710 
/// Direction of an RNN; converted to dnnl_rnn_direction_t below.
// NOTE(review): numbering jumps 711 -> 725; enumerators are not shown in this
// listing.
711 enum class rnn_direction {
725 };
726 
// Fragment (signature line missing): casts `dir` to dnnl_rnn_direction_t.
731  return static_cast<dnnl_rnn_direction_t>(dir);
732 }
733 
735 
738 
/// Primitive descriptor query specification; converted to dnnl_query_t below.
/// It has at least an `engine` value (used as dnnl::query::engine elsewhere in
/// this header).
// NOTE(review): the enumerators themselves are not shown in this listing.
745 enum class query {
748 
753 
758 
765 
770 
775 
778 
781 
816 
835 };
836 
// Fragment (signature line missing): casts `aquery` to dnnl_query_t.
841  return static_cast<dnnl_query_t>(aquery);
842 }
843 
845 
847 
858 
860 template <>
861 struct handle_traits<dnnl_engine_t> {
862  static dnnl_status_t destructor(dnnl_engine_t p) {
863  return dnnl_engine_destroy(p);
864  }
865 };
867 
/// An execution engine: a reference-counted handle to a C API dnnl_engine_t.
869 struct engine : public handle<dnnl_engine_t> {
870  friend struct primitive;
871  friend struct reorder;
872 
/// Kinds of engines; values mirror the C API dnnl_engine_kind_t constants.
874  enum class kind {
878  cpu = dnnl_cpu,
880  gpu = dnnl_gpu,
881  };
882 
// Inherit handle's constructors (wrapping an existing dnnl_engine_t).
883  using handle::handle;
884 
/// Constructs an empty engine.
887  engine() = default;
888 
/// Returns the number of engines of kind `akind` (via dnnl_engine_get_count).
893  static size_t get_count(kind akind) {
894  return dnnl_engine_get_count(convert_to_c(akind));
895  }
896 
/// Creates an engine of kind `akind` with the given index via
/// dnnl_engine_create, and takes ownership of it.
// NOTE(review): numbering jumps 902 -> 905. The declaration of the local
// `engine` and the start of the error-checking call (presumably
// error::wrap_c_api) are missing from this listing.
902  engine(kind akind, size_t index) {
905  dnnl_engine_create(&engine, convert_to_c(akind), index),
906  "could not create an engine");
907  reset(engine);
908  }
909 
// NOTE(review): the signature line of the next definition (presumably a
// constructor from a primitive descriptor) and lines 916-917 are missing.
// The body queries dnnl::query::engine and stores the result as a weak
// (non-owning) handle via reset(c_engine, true).
915  dnnl_engine_t c_engine;
918  dnnl::convert_to_c(dnnl::query::engine), 0, &c_engine),
919  "could not get an engine from a primitive_desc");
920  reset(c_engine, true);
921  }
922 
/// Returns the kind of the engine.
// NOTE(review): numbering jumps 925 -> 928; the local `kind` variable and the
// C query call are missing from this listing.
925  kind get_kind() const {
928  "could not get kind of an engine");
929  return static_cast<engine::kind>(kind);
930  }
931 
/// Returns the engine associated with primitive descriptor `pd`
/// (queried with dnnl::query::engine).
937  template <typename primitive_desc>
938  static engine query(const primitive_desc &pd) {
939  return query(pd, dnnl::query::engine);
940  }
941 
942 private:
// Converts engine::kind to the C enum dnnl_engine_kind_t.
943  static dnnl_engine_kind_t convert_to_c(kind akind) {
944  return static_cast<dnnl_engine_kind_t>(akind);
945  }
946 
// Queries engine `what` from `pd`. The result is constructed with weak = true,
// so the returned engine does not destroy c_engine.
// NOTE(review): line 950 (the start of the query call) is missing here.
947  template <typename primitive_desc>
948  static engine query(const primitive_desc &pd, dnnl::query what) {
949  dnnl_engine_t c_engine;
951  dnnl::convert_to_c(what), 0, &c_engine),
952  "could not get an engine from a primitive_desc");
953  return engine(c_engine, true);
954  }
955 };
956 
// Fragment (signature line missing): casts `akind` (presumably an engine::kind)
// to dnnl_engine_kind_t.
962  return static_cast<dnnl_engine_kind_t>(akind);
963 }
964 
966 
974 
976 template <>
977 struct handle_traits<dnnl_stream_t> {
978  static dnnl_status_t destructor(dnnl_stream_t p) {
979  return dnnl_stream_destroy(p);
980  }
981 };
983 
/// An execution stream: a reference-counted handle to a C API dnnl_stream_t.
985 struct stream : public handle<dnnl_stream_t> {
986  using handle::handle;
987 
/// Stream flags (unsigned underlying type, combinable via the bitwise
/// operators defined below with DNNL_DEFINE_BITMASK_OPS).
// NOTE(review): numbering jumps 991 -> 996. `default_flags`, which the
// constructor below uses, is among the enumerators missing from this listing.
989  enum class flags : unsigned {
991  in_order = dnnl_stream_in_order,
996  };
997 
/// Constructs an empty stream.
1000  stream() = default;
1001 
/// Creates a stream on engine `aengine` with flags `aflags` and takes
/// ownership of it.
// NOTE(review): lines 1008-1009 (local declaration and start of the
// creation call) are missing from this listing.
1007  stream(const engine &aengine, flags aflags = flags::default_flags) {
1010  static_cast<dnnl_stream_flags_t>(aflags)),
1011  "could not create a stream");
1012  reset(stream);
1013  }
1014 
/// Returns the engine of the stream as a weak (non-owning) engine handle.
// NOTE(review): line 1018 (the query call) is missing from this listing.
1016  engine get_engine() const {
1017  dnnl_engine_t c_engine;
1019  "could not get an engine from a stream object");
1020  return engine(c_engine, true);
1021  }
1022 
// NOTE(review): the signature line of the next definition is missing.
// Its body waits on the stream with dnnl_stream_wait and returns *this
// (presumably `stream &wait()`).
1027  dnnl_stream_wait(get()), "could not wait on a stream");
1028  return *this;
1029  }
1030 };
1031 
// Enable |, &, ^, ~ and compound assignments for stream::flags.
1032 DNNL_DEFINE_BITMASK_OPS(stream::flags)
1033 
1034 
1101 
1108 struct memory : public handle<dnnl_memory_t> {
1109  using handle::handle;
1110 
1112  typedef dnnl_dim_t dim;
1115  typedef std::vector<dim> dims;
1116 
1123  template <typename T>
1124  static void validate_dims(const std::vector<T> &v, int min_size = 0) {
1125  validate_container_size(
1126  v, "dimensions are invalid", min_size, DNNL_MAX_NDIMS);
1127  }
1128 
/// Element data types; each value mirrors the corresponding C API constant.
// NOTE(review): numbering jumps 1130 -> 1134, so an entry (for example an
// `undef` value) may be missing from this listing. Confirm.
1130  enum class data_type {
1134  f16 = dnnl_f16,
1137  bf16 = dnnl_bf16,
1139  f32 = dnnl_f32,
1141  s32 = dnnl_s32,
1143  s8 = dnnl_s8,
1145  u8 = dnnl_u8,
1146  };
1147 
/// Memory format kinds; each value mirrors the corresponding C API constant.
// NOTE(review): numbering jumps 1149 -> 1154, so an entry may be missing here.
1149  enum class format_kind {
1154  any = dnnl_format_kind_any,
1158  blocked = dnnl_blocked,
1160  wino = dnnl_format_kind_wino,
1162  packed = dnnl_format_kind_rnn_packed,
1163  };
1164 
/// Memory format tags; each value mirrors a C API dnnl_* format tag.
///
/// Plain tags spell dimensions with letters (a = first logical dimension,
/// b = second, ...) in their memory order. The aliases show this:
/// nhwc = acdb lists n, h, w, c. Upper-case letters followed by a number
/// presumably denote blocked dimensions (e.g. aBc16b); see the oneDNN
/// memory-format documentation to confirm the exact layout semantics.
// NOTE(review): numbering jumps 1205 -> 1210, so an entry (for example an
// `undef` value) may be missing from this listing. Confirm.
1205  enum class format_tag {
1210  any = dnnl_format_tag_any,
1211 
1213  a = dnnl_a,
1214 
1216  ab = dnnl_ab,
1218  ba = dnnl_ba,
1219 
1221  abc = dnnl_abc,
1223  acb = dnnl_acb,
1225  bac = dnnl_bac,
1227  bca = dnnl_bca,
1229  cba = dnnl_cba,
1230 
1232  abcd = dnnl_abcd,
1234  abdc = dnnl_abdc,
1236  acdb = dnnl_acdb,
1238  bacd = dnnl_bacd,
1240  bcda = dnnl_bcda,
1242  cdba = dnnl_cdba,
1244  dcab = dnnl_dcab,
1245 
1247  abcde = dnnl_abcde,
1249  abdec = dnnl_abdec,
1251  acbde = dnnl_acbde,
1253  acdeb = dnnl_acdeb,
1255  bacde = dnnl_bacde,
1257  bcdea = dnnl_bcdea,
1259  cdeba = dnnl_cdeba,
1261  decab = dnnl_decab,
1263  abced = dnnl_abced,
1264 
1266  abcdef = dnnl_abcdef,
1268  abdfce = dnnl_abdfce,
1270  acbdef = dnnl_acbdef,
1272  abdefc = dnnl_abdefc,
1274  defcab = dnnl_defcab,
1276  abcdfe = dnnl_abcdfe,
1277 
1279  abcdefg = dnnl_abcdefg,
1281  abcdegf = dnnl_abcdegf,
1282 
1284  abcdefgh = dnnl_abcdefgh,
1286  abcdefhg = dnnl_abcdefhg,
1287 
1289  abcdefghi = dnnl_abcdefghi,
1291  abcdefgih = dnnl_abcdefgih,
1292 
1294  abcdefghij = dnnl_abcdefghij,
1296  abcdefghji = dnnl_abcdefghji,
1297 
1299  abcdefghijk = dnnl_abcdefghijk,
1301  abcdefghikj = dnnl_abcdefghikj,
1302 
1304  abcdefghijkl = dnnl_abcdefghijkl,
1306  abcdefghijlk = dnnl_abcdefghijlk,
1307 
// Aliases of the plain tags above, named after common activation and
// weights layouts (several names share one value, e.g. nc, tn and oi are ab).
1309  x = a,
1311  nc = ab,
1313  cn = ba,
1315  tn = ab,
1317  nt = ba,
1319  ncw = abc,
1321  nwc = acb,
1323  nchw = abcd,
1325  nhwc = acdb,
1327  chwn = bcda,
1329  ncdhw = abcde,
1331  ndhwc = acdeb,
1332 
1334  oi = ab,
1336  io = ba,
1338  oiw = abc,
1340  owi = acb,
1342  wio = cba,
1344  iwo = bca,
1346  oihw = abcd,
1348  hwio = cdba,
1350  ohwi = acdb,
1352  ihwo = bcda,
1354  iohw = bacd,
1356  oidhw = abcde,
1358  dhwio = cdeba,
1360  odhwi = acdeb,
1362  iodhw = bacde,
1364  idhwo = bcdea,
1365 
// Aliases with a leading group dimension `g`.
1367  goiw = abcd,
1369  gowi = abdc,
1371  wigo = dcab,
1373  gohwi = abdec,
1375  goihw = abcde,
1377  hwigo = decab,
1379  giohw = acbde,
1381  goidhw = abcdef,
1383  giodhw = acbdef,
1385  godhwi = abdefc,
1387  dhwigo = defcab,
1388 
// Aliases whose names suggest RNN tensors (presumably l = layer,
// d = direction, g = gate; confirm against the oneDNN documentation).
1391  tnc = abc,
1394  ntc = bac,
1397  ldnc = abcd,
1405  ldigo = abcde,
1413  ldgoi = abdec,
1417  ldio = abcd,
1421  ldoi = abdc,
1429  ldgo = abcd,
1430 
1431  // Opaque blocked formats
1432 
1433  AB16b16a = dnnl_AB16b16a,
1434  AB16b32a = dnnl_AB16b32a,
1435  AB16b64a = dnnl_AB16b64a,
1436  AB8b16a2b = dnnl_AB8b16a2b,
1437  AB8b32a2b = dnnl_AB8b32a2b,
1438  AB8b64a2b = dnnl_AB8b64a2b,
1439  AB4b16a4b = dnnl_AB4b16a4b,
1440  AB4b32a4b = dnnl_AB4b32a4b,
1441  AB4b64a4b = dnnl_AB4b64a4b,
1442  AB16b16a4b = dnnl_AB16b16a4b,
1443  Abc16a = dnnl_Abc16a,
1444  ABc16a16b = dnnl_ABc16a16b,
1445  ABc4a4b = dnnl_ABc4a4b,
1446  aBc16b = dnnl_aBc16b,
1447  aBc32b = dnnl_aBc32b,
1448  ABc16b16a = dnnl_ABc16b16a,
1449  ABc16b32a = dnnl_ABc16b32a,
1450  ABc16b64a = dnnl_ABc16b64a,
1451  Abc4a = dnnl_Abc4a,
1452  aBc4b = dnnl_aBc4b,
1453  ABc4b16a4b = dnnl_ABc4b16a4b,
1454  ABc4b32a4b = dnnl_ABc4b32a4b,
1455  ABc4b64a4b = dnnl_ABc4b64a4b,
1456  ABc2b8a4b = dnnl_ABc2b8a4b,
1457  ABc16a16b2a = dnnl_ABc16a16b2a,
1458  ABc16b16a4b = dnnl_ABc16b16a4b,
1459  ABc16b16a2b = dnnl_ABc16b16a2b,
1460  ABc4b4a = dnnl_ABc4b4a,
1461  ABc8a16b2a = dnnl_ABc8a16b2a,
1462  ABc8a8b = dnnl_ABc8a8b,
1463  ABc8a4b = dnnl_ABc8a4b,
1464  aBc8b = dnnl_aBc8b,
1465  ABc8b16a2b = dnnl_ABc8b16a2b,
1466  ABc8b32a2b = dnnl_ABc8b32a2b,
1467  ABc8b64a2b = dnnl_ABc8b64a2b,
1468  ABc8b8a = dnnl_ABc8b8a,
1469  Abcd8a = dnnl_Abcd8a,
1470  Abcd16a = dnnl_Abcd16a,
1471  Abcd32a = dnnl_Abcd32a,
1472  ABcd16a16b = dnnl_ABcd16a16b,
1473  aBcd16b = dnnl_aBcd16b,
1474  aBcd32b = dnnl_aBcd32b,
1475  ABcd16b16a = dnnl_ABcd16b16a,
1476  ABcd16b32a = dnnl_ABcd16b32a,
1477  ABcd16b64a = dnnl_ABcd16b64a,
1478  aBCd16b16c = dnnl_aBCd16b16c,
1479  aBCd16c16b = dnnl_aBCd16c16b,
1480  Abcd4a = dnnl_Abcd4a,
1481  aBcd4b = dnnl_aBcd4b,
1482  ABcd4b16a4b = dnnl_ABcd4b16a4b,
1483  ABcd4b32a4b = dnnl_ABcd4b32a4b,
1484  ABcd4b64a4b = dnnl_ABcd4b64a4b,
1485  ABcd2b8a4b = dnnl_ABcd2b8a4b,
1486  ABcd4b4a = dnnl_ABcd4b4a,
1487  ABcd4a4b = dnnl_ABcd4a4b,
1488  aBCd4c16b4c = dnnl_aBCd4c16b4c,
1489  aBCd2c8b4c = dnnl_aBCd2c8b4c,
1490  ABcd16a16b2a = dnnl_ABcd16a16b2a,
1491  ABcd16b16a4b = dnnl_ABcd16b16a4b,
1492  ABcd16b16a2b = dnnl_ABcd16b16a2b,
1493  aBCd16b16c2b = dnnl_aBCd16b16c2b,
1494  aBCd16c16b4c = dnnl_aBCd16c16b4c,
1495  aBCd16c16b2c = dnnl_aBCd16c16b2c,
1496  aBCd4c4b = dnnl_aBCd4c4b,
1497  aBCd4b4c = dnnl_aBCd4b4c,
1498  ABcd8a16b2a = dnnl_ABcd8a16b2a,
1499  ABcd8a8b = dnnl_ABcd8a8b,
1500  ABcd8a4b = dnnl_ABcd8a4b,
1502  aBcd8b = dnnl_aBcd8b,
1503  ABcd8b16a2b = dnnl_ABcd8b16a2b,
1504  ABcd8b32a2b = dnnl_ABcd8b32a2b,
1505  ABcd8b64a2b = dnnl_ABcd8b64a2b,
1506  aBCd8b16c2b = dnnl_aBCd8b16c2b,
1508  ABcd8b8a = dnnl_ABcd8b8a,
1509  aBCd8b8c = dnnl_aBCd8b8c,
1510  aBCd8b4c = dnnl_aBCd8b4c,
1511  aBCd8c16b2c = dnnl_aBCd8c16b2c,
1512  aBCd8c8b = dnnl_aBCd8c8b,
1513  Abcde16a = dnnl_Abcde16a,
1514  Abcde32a = dnnl_Abcde32a,
1515  ABcde16a16b = dnnl_ABcde16a16b,
1516  aBcde16b = dnnl_aBcde16b,
1517  aBcde32b = dnnl_aBcde32b,
1518  ABcde16b16a = dnnl_ABcde16b16a,
1519  ABcde16b32a = dnnl_ABcde16b32a,
1520  ABcde16b64a = dnnl_ABcde16b64a,
1521  aBCde16b16c = dnnl_aBCde16b16c,
1522  aBCde16c16b = dnnl_aBCde16c16b,
1523  aBCde2c8b4c = dnnl_aBCde2c8b4c,
1524  Abcde4a = dnnl_Abcde4a,
1525  aBcde4b = dnnl_aBcde4b,
1526  ABcde4b4a = dnnl_ABcde4b4a,
1527  ABcde4a4b = dnnl_ABcde4a4b,
1528  aBCde4b4c = dnnl_aBCde4b4c,
1529  aBCde4c16b4c = dnnl_aBCde4c16b4c,
1530  aBCde16b16c2b = dnnl_aBCde16b16c2b,
1531  aBCde16c16b4c = dnnl_aBCde16c16b4c,
1532  aBCde16c16b2c = dnnl_aBCde16c16b2c,
1533  aBCdef16c16b2c = dnnl_aBCdef16c16b2c,
1534  aBCde4c4b = dnnl_aBCde4c4b,
1535  Abcde8a = dnnl_Abcde8a,
1536  ABcde8a8b = dnnl_ABcde8a8b,
1537  ABcde8a4b = dnnl_ABcde8a4b,
1538  aBcde8b = dnnl_aBcde8b,
1539  ABcde8b16a2b = dnnl_ABcde8b16a2b,
1540  ABcde8b32a2b = dnnl_ABcde8b32a2b,
1541  ABcde8b64a2b = dnnl_ABcde8b64a2b,
1542  ABcde4b16a4b = dnnl_ABcde4b16a4b,
1543  ABcde4b32a4b = dnnl_ABcde4b32a4b,
1544  ABcde4b64a4b = dnnl_ABcde4b64a4b,
1545  ABcde16b16a4b = dnnl_ABcde16b16a4b,
1546  ABcde16b16a2b = dnnl_ABcde16b16a2b,
1547  ABcde2b8a4b = dnnl_ABcde2b8a4b,
1548  aBCde8b16c2b = dnnl_aBCde8b16c2b,
1549  ABcde8b8a = dnnl_ABcde8b8a,
1550  aBCde8b8c = dnnl_aBCde8b8c,
1551  aBCde8b4c = dnnl_aBCde8b4c,
1552  ABcd4a8b8a4b = dnnl_ABcd4a8b8a4b,
1553  ABcd2a8b8a2b = dnnl_ABcd2a8b8a2b,
1554  aBCde4b8c8b4c = dnnl_aBCde4b8c8b4c,
1555  aBCde2b8c8b2c = dnnl_aBCde2b8c8b2c,
1556  aBCde8c16b2c = dnnl_aBCde8c16b2c,
1557  aBCde8c8b = dnnl_aBCde8c8b,
1558  aBcdef16b = dnnl_aBcdef16b,
1559  aBCdef16b16c = dnnl_aBCdef16b16c,
1560  aBCdef16c16b = dnnl_aBCdef16c16b,
1561  aBcdef4b = dnnl_aBcdef4b,
1562  aBCdef2c8b4c = dnnl_aBCdef2c8b4c,
1563  aBCdef4c4b = dnnl_aBCdef4c4b,
1564  aBCdef4b4c = dnnl_aBCdef4b4c,
1565  aBCdef8b8c = dnnl_aBCdef8b8c,
1566  aBCdef8b4c = dnnl_aBCdef8b4c,
1567  aBCdef8c16b2c = dnnl_aBCdef8c16b2c,
1568  aBCdef4c16b4c = dnnl_aBCdef4c16b4c,
1569  aBCdef8c8b = dnnl_aBCdef8c8b,
1570  aBdc16b = dnnl_aBdc16b,
1571  aBdc4b = dnnl_aBdc4b,
1572  aBdc8b = dnnl_aBdc8b,
1573  aBdec16b = dnnl_aBdec16b,
1574  aBdec4b = dnnl_aBdec4b,
1575  aBdec8b = dnnl_aBdec8b,
1576  aBdefc16b = dnnl_aBdefc16b,
1577  aCBdef16c16b = dnnl_aCBdef16c16b,
1578  aCBdef16b16c = dnnl_aCBdef16b16c,
1579  aBdefc4b = dnnl_aBdefc4b,
1580  aBdefc8b = dnnl_aBdefc8b,
1581  Acb16a = dnnl_Acb16a,
1582  Acb4a = dnnl_Acb4a,
1583  Acb8a = dnnl_Acb8a,
1584  aCBd16b16c = dnnl_aCBd16b16c,
1585  aCBd16c16b = dnnl_aCBd16c16b,
1586  aCBde16b16c = dnnl_aCBde16b16c,
1587  aCBde16c16b = dnnl_aCBde16c16b,
1588  Acdb16a = dnnl_Acdb16a,
1589  Acdb4a = dnnl_Acdb4a,
1590  Acdb8a = dnnl_Acdb8a,
1591  Acdeb16a = dnnl_Acdeb16a,
1592  Acdeb4a = dnnl_Acdeb4a,
1593  Acdeb8a = dnnl_Acdeb8a,
1594  BAc16a16b = dnnl_BAc16a16b,
1595  BAc16b16a = dnnl_BAc16b16a,
1596  BAcd16a16b = dnnl_BAcd16a16b,
1597  BAcd16b16a = dnnl_BAcd16b16a,
1598  ABcd32a32b = dnnl_ABcd32a32b,
1599  BAcde16b16a = dnnl_BAcde16b16a,
1600  BAcde16a16b = dnnl_BAcde16a16b,
1601  aBdec32b = dnnl_aBdec32b,
1602  Abcdef16a = dnnl_Abcdef16a,
1603  Abcdef32a = dnnl_Abcdef32a,
1604  Acdb32a = dnnl_Acdb32a,
1605  aBCd2b4c2b = dnnl_aBCd2b4c2b,
1606  aBCde2b4c2b = dnnl_aBCde2b4c2b,
1607  aBCdef2b4c2b = dnnl_aBCdef2b4c2b,
1608  aBCd2c4b2c = dnnl_aBCd2c4b2c,
1609  aBCde2c4b2c = dnnl_aBCde2c4b2c,
1610  aBCdef2c4b2c = dnnl_aBCdef2c4b2c,
1611  aBCd4b8c2b = dnnl_aBCd4b8c2b,
1612  aBCde4b8c2b = dnnl_aBCde4b8c2b,
1613  aBCdef4b8c2b = dnnl_aBCdef4b8c2b,
1614  aBCd4c8b2c = dnnl_aBCd4c8b2c,
1615  aBCde4c8b2c = dnnl_aBCde4c8b2c,
1616  aBCdef4c8b2c = dnnl_aBCdef4c8b2c,
1617  AB32a32b8a4b = dnnl_AB32a32b8a4b,
1618  AB32a32b8a2b = dnnl_AB32a32b8a2b,
1619  AB8a4b = dnnl_AB8a4b,
1620  AB8a2b = dnnl_AB8a2b,
1621  abDc32d = dnnl_abDc32d,
1622  abDC32d4c = dnnl_abDC32d4c,
1623  abdEc32e = dnnl_abdEc32e,
1624  abdEC32e2c = dnnl_abdEC32e2c,
1625  abdEC32e4c = dnnl_abdEC32e4c,
1626  aBCdef16c16b4c = dnnl_aBCdef16c16b4c,
1627  aBdC16b4c = dnnl_aBdC16b4c,
1628  aBdeC16b4c = dnnl_aBdeC16b4c,
1629  AcB16a4b = dnnl_AcB16a4b,
1630  AcdB16a2b = dnnl_AcdB16a2b,
1631  aBdefC16b4c = dnnl_aBdefC16b4c,
1632  AcdeB16a4b = dnnl_AcdeB16a4b,
1633 
1634  Acb32a = dnnl_Acb32a,
1635  AcB32a2b = dnnl_AcB32a2b,
1636  AcB32a4b = dnnl_AcB32a4b,
1637  Acb48a = dnnl_Acb48a,
1638  AcB48a2b = dnnl_AcB48a2b,
1639  AcB48a4b = dnnl_AcB48a4b,
1640  Acb64a = dnnl_Acb64a,
1641  AcB64a2b = dnnl_AcB64a2b,
1642  AcB64a4b = dnnl_AcB64a4b,
1643  cBa2b = dnnl_cBa2b,
1644  cBa4b = dnnl_cBa4b,
1645  aBdc32b = dnnl_aBdc32b,
1646  aBdC32b2c = dnnl_aBdC32b2c,
1647  aBdC32b4c = dnnl_aBdC32b4c,
1648  aBdc48b = dnnl_aBdc48b,
1649  aBdC48b2c = dnnl_aBdC48b2c,
1650  aBdC48b4c = dnnl_aBdC48b4c,
1651  aBdc64b = dnnl_aBdc64b,
1652  aBdC64b2c = dnnl_aBdC64b2c,
1653  aBdC64b4c = dnnl_aBdC64b4c,
1654  adcb = dnnl_adcb,
1655  adCb2c = dnnl_adCb2c,
1656  adCb4c = dnnl_adCb4c,
1657  AcdB32a2b = dnnl_AcdB32a2b,
1658  AcdB32a4b = dnnl_AcdB32a4b,
1659  Acdb48a = dnnl_Acdb48a,
1660  AcdB48a2b = dnnl_AcdB48a2b,
1661  AcdB48a4b = dnnl_AcdB48a4b,
1662  Acdb64a = dnnl_Acdb64a,
1663  AcdB64a2b = dnnl_AcdB64a2b,
1664  AcdB64a4b = dnnl_AcdB64a4b,
1665  cdBa2b = dnnl_cdBa2b,
1666  cdBa4b = dnnl_cdBa4b,
1667  aBdeC32b2c = dnnl_aBdeC32b2c,
1668  aBdeC32b4c = dnnl_aBdeC32b4c,
1669  aBdec48b = dnnl_aBdec48b,
1670  aBdeC48b2c = dnnl_aBdeC48b2c,
1671  aBdeC48b4c = dnnl_aBdeC48b4c,
1672  aBdec64b = dnnl_aBdec64b,
1673  aBdeC64b2c = dnnl_aBdeC64b2c,
1674  aBdeC64b4c = dnnl_aBdeC64b4c,
1675  adecb = dnnl_adecb,
1676  adeCb2c = dnnl_adeCb2c,
1677  adeCb4c = dnnl_adeCb4c,
1678  Acdeb32a = dnnl_Acdeb32a,
1679  AcdeB32a2b = dnnl_AcdeB32a2b,
1680  AcdeB32a4b = dnnl_AcdeB32a4b,
1681  Acdeb48a = dnnl_Acdeb48a,
1682  AcdeB48a2b = dnnl_AcdeB48a2b,
1683  AcdeB48a4b = dnnl_AcdeB48a4b,
1684  Acdeb64a = dnnl_Acdeb64a,
1685  AcdeB64a2b = dnnl_AcdeB64a2b,
1686  AcdeB64a4b = dnnl_AcdeB64a4b,
1687  cdeBa2b = dnnl_cdeBa2b,
1688  cdeBa4b = dnnl_cdeBa4b,
1689  aBdefc32b = dnnl_aBdefc32b,
1690  aBdefC32b2c = dnnl_aBdefC32b2c,
1691  aBdefC32b4c = dnnl_aBdefC32b4c,
1692  aBdefc48b = dnnl_aBdefc48b,
1693  aBdefC48b2c = dnnl_aBdefC48b2c,
1694  aBdefC48b4c = dnnl_aBdefC48b4c,
1695  aBdefc64b = dnnl_aBdefc64b,
1696  aBdefC64b2c = dnnl_aBdefC64b2c,
1697  aBdefC64b4c = dnnl_aBdefC64b4c,
1698  adefcb = dnnl_adefcb,
1699  adefCb2c = dnnl_adefCb2c,
1700  adefCb4c = dnnl_adefCb4c,
1701 
// Sentinel mirroring dnnl_format_tag_last. NOTE(review): the entries after
// it are presumably domain-named aliases of earlier tags (as in the C API).
// Confirm before relying on that.
1702  format_tag_last = dnnl_format_tag_last,
1703 
1704  nCdhw16c = dnnl_nCdhw16c,
1705  nCdhw4c = dnnl_nCdhw4c,
1706  nCdhw8c = dnnl_nCdhw8c,
1707  nChw16c = dnnl_nChw16c,
1708  nChw4c = dnnl_nChw4c,
1709  nChw8c = dnnl_nChw8c,
1710  nCw16c = dnnl_nCw16c,
1711  nCw4c = dnnl_nCw4c,
1712  nCw8c = dnnl_nCw8c,
1713  NCw16n16c = dnnl_NCw16n16c,
1714  NChw16n16c = dnnl_NChw16n16c,
1715  NCdhw16n16c = dnnl_NCdhw16n16c,
1716  NCdhw32n32c = dnnl_NCdhw32n32c,
1717  NChw32n32c = dnnl_NChw32n32c,
1718  IOhw16i16o = dnnl_IOhw16i16o,
1719  OI16i16o = dnnl_OI16i16o,
1720  OI16i32o = dnnl_OI16i32o,
1721  OI16i64o = dnnl_OI16i64o,
1722  OI8i16o2i = dnnl_OI8i16o2i,
1723  OI8i32o2i = dnnl_OI8i32o2i,
1724  OI8i64o2i = dnnl_OI8i64o2i,
1725  OI4i16o4i = dnnl_OI4i16o4i,
1726  OI4i32o4i = dnnl_OI4i32o4i,
1727  OI4i64o4i = dnnl_OI4i64o4i,
1728  Ohwi32o = dnnl_Ohwi32o,
1729  IOdhw16i16o = dnnl_IOdhw16i16o,
1730  gIOhw16i16o = dnnl_gIOhw16i16o,
1731  gOhwi32o = dnnl_gOhwi32o,
1732  Goidhw16g = dnnl_Goidhw16g,
1733  IOw16o16i = dnnl_IOw16o16i,
1734  OIw16i16o = dnnl_OIw16i16o,
1735  OIw16i32o = dnnl_OIw16i32o,
1736  OIw16i64o = dnnl_OIw16i64o,
1737  IOw16i16o = dnnl_IOw16i16o,
1738  gIOw16i16o = dnnl_gIOw16i16o,
1739  OIw16o16i = dnnl_OIw16o16i,
1740  Oiw16o = dnnl_Oiw16o,
1741  OIw4i16o4i = dnnl_OIw4i16o4i,
1742  OIw4i32o4i = dnnl_OIw4i32o4i,
1743  OIw4i64o4i = dnnl_OIw4i64o4i,
1744  OIw2i8o4i = dnnl_OIw2i8o4i,
1745  OIw4i4o = dnnl_OIw4i4o,
1746  OIw4o4i = dnnl_OIw4o4i,
1747  Oiw4o = dnnl_Oiw4o,
1748  OIw8i16o2i = dnnl_OIw8i16o2i,
1749  OIw8i32o2i = dnnl_OIw8i32o2i,
1750  OIw8i64o2i = dnnl_OIw8i64o2i,
1751  OIw8i8o = dnnl_OIw8i8o,
1752  OIw8o16i2o = dnnl_OIw8o16i2o,
1753  OIw8o8i = dnnl_OIw8o8i,
1754  OIw8o4i = dnnl_OIw8o4i,
1755  OIw16i16o4i = dnnl_OIw16i16o4i,
1756  OIw16i16o2i = dnnl_OIw16i16o2i,
1757  OIw16o16i2o = dnnl_OIw16o16i2o,
1758  Owi16o = dnnl_Owi16o,
1759  OwI16o2i = dnnl_OwI16o2i,
1760  Owi4o = dnnl_Owi4o,
1761  Owi8o = dnnl_Owi8o,
1762  IOhw16o16i = dnnl_IOhw16o16i,
1763  Ohwi16o = dnnl_Ohwi16o,
1764  OhwI16o2i = dnnl_OhwI16o2i,
1765  Ohwi4o = dnnl_Ohwi4o,
1766  Ohwi8o = dnnl_Ohwi8o,
1767  OIhw16i16o = dnnl_OIhw16i16o,
1768  OIhw16i32o = dnnl_OIhw16i32o,
1769  OIhw16i64o = dnnl_OIhw16i64o,
1770  OIhw16o16i = dnnl_OIhw16o16i,
1771  Oihw16o = dnnl_Oihw16o,
1772  OIhw4i16o4i = dnnl_OIhw4i16o4i,
1773  OIhw4i32o4i = dnnl_OIhw4i32o4i,
1774  OIhw4i64o4i = dnnl_OIhw4i64o4i,
1775  OIhw4i4o = dnnl_OIhw4i4o,
1776  OIhw4o4i = dnnl_OIhw4o4i,
1777  Oihw4o = dnnl_Oihw4o,
1778  OIhw8i16o2i = dnnl_OIhw8i16o2i,
1779  OIhw8i32o2i = dnnl_OIhw8i32o2i,
1780  OIhw8i64o2i = dnnl_OIhw8i64o2i,
1781  OIhw8i8o = dnnl_OIhw8i8o,
1782  OIhw8o16i2o = dnnl_OIhw8o16i2o,
1783  OIhw8o8i = dnnl_OIhw8o8i,
1784  OIhw8o4i = dnnl_OIhw8o4i,
1785  OIhw2i8o4i = dnnl_OIhw2i8o4i,
1786  IOdhw16o16i = dnnl_IOdhw16o16i,
1787  Odhwi16o = dnnl_Odhwi16o,
1788  OdhwI16o2i = dnnl_OdhwI16o2i,
1789  Odhwi4o = dnnl_Odhwi4o,
1790  Odhwi8o = dnnl_Odhwi8o,
1791  OIdhw16i16o = dnnl_OIdhw16i16o,
1792  OIdhw16i32o = dnnl_OIdhw16i32o,
1793  OIdhw16i64o = dnnl_OIdhw16i64o,
1794  OIdhw16o16i = dnnl_OIdhw16o16i,
1795  Oidhw16o = dnnl_Oidhw16o,
1796  OIdhw4i4o = dnnl_OIdhw4i4o,
1797  OIdhw4o4i = dnnl_OIdhw4o4i,
1798  Oidhw4o = dnnl_Oidhw4o,
1799  OIdhw8i16o2i = dnnl_OIdhw8i16o2i,
1800  OIdhw8i32o2i = dnnl_OIdhw8i32o2i,
1801  OIdhw8i64o2i = dnnl_OIdhw8i64o2i,
1802  OIdhw4i16o4i = dnnl_OIdhw4i16o4i,
1803  OIdhw16i16o4i = dnnl_OIdhw16i16o4i,
1804  OIdhw4i32o4i = dnnl_OIdhw4i32o4i,
1805  OIdhw4i64o4i = dnnl_OIdhw4i64o4i,
1806  OIdhw2i8o4i = dnnl_OIdhw2i8o4i,
1807  OIdhw8i8o = dnnl_OIdhw8i8o,
1808  OIdhw8o8i = dnnl_OIdhw8o8i,
1809  OIdhw8o4i = dnnl_OIdhw8o4i,
1810  gIOw16o16i = dnnl_gIOw16o16i,
1811  gOIw16i16o = dnnl_gOIw16i16o,
1812  gOIw16o16i = dnnl_gOIw16o16i,
1813  gOiw16o = dnnl_gOiw16o,
1814  gOIw4i16o4i = dnnl_gOIw4i16o4i,
1815  gOIw2i8o4i = dnnl_gOIw2i8o4i,
1816  gOIw4i4o = dnnl_gOIw4i4o,
1817  gOIw4o4i = dnnl_gOIw4o4i,
1818  gOiw4o = dnnl_gOiw4o,
1819  gOIw8i16o2i = dnnl_gOIw8i16o2i,
1820  gOIw8i8o = dnnl_gOIw8i8o,
1821  gOIw8o16i2o = dnnl_gOIw8o16i2o,
1822  gOIw8o8i = dnnl_gOIw8o8i,
1823  gOIw8o4i = dnnl_gOIw8o4i,
1824  gOIw16i16o4i = dnnl_gOIw16i16o4i,
1825  gOIw16i16o2i = dnnl_gOIw16i16o2i,
1826  gOIw16o16i2o = dnnl_gOIw16o16i2o,
1827  gOwi16o = dnnl_gOwi16o,
1828  gOwI16o2i = dnnl_gOwI16o2i,
1829  gOwi4o = dnnl_gOwi4o,
1830  gOwi8o = dnnl_gOwi8o,
1831  Goiw8g = dnnl_Goiw8g,
1832  Goiw16g = dnnl_Goiw16g,
1833  gIOhw16o16i = dnnl_gIOhw16o16i,
1834  gOhwi16o = dnnl_gOhwi16o,
1835  gOhwI16o2i = dnnl_gOhwI16o2i,
1836  gOhwi4o = dnnl_gOhwi4o,
1837  gOhwi8o = dnnl_gOhwi8o,
1838  Goihw16g = dnnl_Goihw16g,
1839  gOIhw16i16o = dnnl_gOIhw16i16o,
1840  gOIhw16o16i = dnnl_gOIhw16o16i,
1841  gOihw16o = dnnl_gOihw16o,
1842  gOIhw4i16o4i = dnnl_gOIhw4i16o4i,
1843  gOIhw2i8o4i = dnnl_gOIhw2i8o4i,
1844  gOIhw4i4o = dnnl_gOIhw4i4o,
1845  gOIhw4o4i = dnnl_gOIhw4o4i,
1846  gOihw4o = dnnl_gOihw4o,
1847  Goihw8g = dnnl_Goihw8g,
1848  gOIhw8i16o2i = dnnl_gOIhw8i16o2i,
1849  gOIhw8i8o = dnnl_gOIhw8i8o,
1850  gOIhw8o16i2o = dnnl_gOIhw8o16i2o,
1851  OIw4o8i8o4i = dnnl_OIw4o8i8o4i,
1852  OIdhw4o8i8o4i = dnnl_OIdhw4o8i8o4i,
1853  OIhw4o8i8o4i = dnnl_OIhw4o8i8o4i,
1854  OIhw2o8i8o2i = dnnl_OIhw2o8i8o2i,
1855  gOIw4o8i8o4i = dnnl_gOIw4o8i8o4i,
1856  gOIdhw4o8i8o4i = dnnl_gOIdhw4o8i8o4i,
1857  gOIhw4o8i8o4i = dnnl_gOIhw4o8i8o4i,
1858  gOIhw2o8i8o2i = dnnl_gOIhw2o8i8o2i,
1859  OIhw16i16o4i = dnnl_OIhw16i16o4i,
1860  OIhw16i16o2i = dnnl_OIhw16i16o2i,
1861  OIhw16o16i2o = dnnl_OIhw16o16i2o,
1862  OIdhw16i16o2i = dnnl_OIdhw16i16o2i,
1863  gOIhw16i16o4i = dnnl_gOIhw16i16o4i,
1864  gOIhw16i16o2i = dnnl_gOIhw16i16o2i,
1865  gOIhw16o16i2o = dnnl_gOIhw16o16i2o,
1866  gOIhw8o8i = dnnl_gOIhw8o8i,
1867  gOIhw8o4i = dnnl_gOIhw8o4i,
1868  gIOdhw16i16o = dnnl_gIOdhw16i16o,
1869  gIOdhw16o16i = dnnl_gIOdhw16o16i,
1870  gOdhwi16o = dnnl_gOdhwi16o,
1871  gOdhwI16o2i = dnnl_gOdhwI16o2i,
1872  gOdhwi4o = dnnl_gOdhwi4o,
1873  gOdhwi8o = dnnl_gOdhwi8o,
1874  gOIdhw16i16o = dnnl_gOIdhw16i16o,
1875  gOIdhw16o16i = dnnl_gOIdhw16o16i,
1876  gOidhw16o = dnnl_gOidhw16o,
1877  gOIdhw4i4o = dnnl_gOIdhw4i4o,
1878  gOIdhw4o4i = dnnl_gOIdhw4o4i,
1879  gOidhw4o = dnnl_gOidhw4o,
1880  gOIdhw8i16o2i = dnnl_gOIdhw8i16o2i,
1881  gOIdhw4i16o4i = dnnl_gOIdhw4i16o4i,
1882  gOIdhw16i16o4i = dnnl_gOIdhw16i16o4i,
1883  gOIdhw16i16o2i = dnnl_gOIdhw16i16o2i,
1884  gOIdhw2i8o4i = dnnl_gOIdhw2i8o4i,
1885  gOIdhw8i8o = dnnl_gOIdhw8i8o,
1886  gOIdhw8o8i = dnnl_gOIdhw8o8i,
1887  gOIdhw8o4i = dnnl_gOIdhw8o4i,
1888  gOIw2i4o2i = dnnl_gOIw2i4o2i,
1889  gOIhw2i4o2i = dnnl_gOIhw2i4o2i,
1890  gOIdhw2i4o2i = dnnl_gOIdhw2i4o2i,
1891  gOIw2o4i2o = dnnl_gOIw2o4i2o,
1892  gOIhw2o4i2o = dnnl_gOIhw2o4i2o,
1893  gOIdhw2o4i2o = dnnl_gOIdhw2o4i2o,
1894  gOIw4i8o2i = dnnl_gOIw4i8o2i,
1895  gOIhw4i8o2i = dnnl_gOIhw4i8o2i,
1896  gOIdhw4i8o2i = dnnl_gOIdhw4i8o2i,
1897  gOIw4o8i2o = dnnl_gOIw4o8i2o,
1898  gOIhw4o8i2o = dnnl_gOIhw4o8i2o,
1899  gOIdhw4o8i2o = dnnl_gOIdhw4o8i2o,
1900  ldOi32o = abDc32d,
1901  ldOI32o4i = abDC32d4c,
1902  ldgOi32o = abdEc32e,
1903  ldgOI32o2i = abdEC32e2c,
1904  ldgOI32o4i = abdEC32e4c,
1905  OwI16o4i = dnnl_OwI16o4i,
1906  OhwI16o4i = dnnl_OhwI16o4i,
1907  gOwI16o4i = dnnl_gOwI16o4i,
1908  gOhwI16o4i = dnnl_gOhwI16o4i,
1909  OdhwI16o4i = dnnl_OdhwI16o4i,
1910  gOdhwI16o4i = dnnl_gOdhwI16o4i,
1911 
1912  Owi32o = dnnl_Owi32o,
1913  OwI32o2i = dnnl_OwI32o2i,
1914  OwI32o4i = dnnl_OwI32o4i,
1915  Owi48o = dnnl_Owi48o,
1916  OwI48o2i = dnnl_OwI48o2i,
1917  OwI48o4i = dnnl_OwI48o4i,
1918  Owi64o = dnnl_Owi64o,
1919  OwI64o2i = dnnl_OwI64o2i,
1920  OwI64o4i = dnnl_OwI64o4i,
1921  wIo2i = dnnl_wIo2i,
1922  wIo4i = dnnl_wIo4i,
1923  gOwi32o = dnnl_gOwi32o,
1924  gOwI32o2i = dnnl_gOwI32o2i,
1925  gOwI32o4i = dnnl_gOwI32o4i,
1926  gOwi48o = dnnl_gOwi48o,
1927  gOwI48o2i = dnnl_gOwI48o2i,
1928  gOwI48o4i = dnnl_gOwI48o4i,
1929  gOwi64o = dnnl_gOwi64o,
1930  gOwI64o2i = dnnl_gOwI64o2i,
1931  gOwI64o4i = dnnl_gOwI64o4i,
1932  gwio = dnnl_gwio,
1933  gwIo2i = dnnl_gwIo2i,
1934  gwIo4i = dnnl_gwIo4i,
1935  OhwI32o = dnnl_OhwI32o,
1936  OhwI32o2i = dnnl_OhwI32o2i,
1937  OhwI32o4i = dnnl_OhwI32o4i,
1938  Ohwi48o = dnnl_Ohwi48o,
1939  OhwI48o2i = dnnl_OhwI48o2i,
1940  OhwI48o4i = dnnl_OhwI48o4i,
1941  Ohwi64o = dnnl_Ohwi64o,
1942  OhwI64o2i = dnnl_OhwI64o2i,
1943  OhwI64o4i = dnnl_OhwI64o4i,
1944  hwIo2i = dnnl_hwIo2i,
1945  hwIo4i = dnnl_hwIo4i,
1946  gOhwI32o = dnnl_gOhwI32o,
1947  gOhwI32o2i = dnnl_gOhwI32o2i,
1948  gOhwI32o4i = dnnl_gOhwI32o4i,
1949  gOhwi48o = dnnl_gOhwi48o,
1950  gOhwI48o2i = dnnl_gOhwI48o2i,
1951  gOhwI48o4i = dnnl_gOhwI48o4i,
1952  gOhwi64o = dnnl_gOhwi64o,
1953  gOhwI64o2i = dnnl_gOhwI64o2i,
1954  gOhwI64o4i = dnnl_gOhwI64o4i,
1955  ghwio = dnnl_ghwio,
1956  ghwIo2i = dnnl_ghwIo2i,
1957  ghwIo4i = dnnl_ghwIo4i,
1958  Odhwi32o = dnnl_Odhwi32o,
1959  OdhwI32o2i = dnnl_OdhwI32o2i,
1960  OdhwI32o4i = dnnl_OdhwI32o4i,
1961  Odhwi48o = dnnl_Odhwi48o,
1962  OdhwI48o2i = dnnl_OdhwI48o2i,
1963  OdhwI48o4i = dnnl_OdhwI48o4i,
1964  Odhwi64o = dnnl_Odhwi64o,
1965  OdhwI64o2i = dnnl_OdhwI64o2i,
1966  OdhwI64o4i = dnnl_OdhwI64o4i,
1967  dhwIo2i = dnnl_dhwIo2i,
1968  dhwIo4i = dnnl_dhwIo4i,
1969  gOdhwi32o = dnnl_gOdhwi32o,
1970  gOdhwI32o2i = dnnl_gOdhwI32o2i,
1971  gOdhwI32o4i = dnnl_gOdhwI32o4i,
1972  gOdhwi48o = dnnl_gOdhwi48o,
1973  gOdhwI48o2i = dnnl_gOdhwI48o2i,
1974  gOdhwI48o4i = dnnl_gOdhwI48o4i,
1975  gOdhwi64o = dnnl_gOdhwi64o,
1976  gOdhwI64o2i = dnnl_gOdhwI64o2i,
1977  gOdhwI64o4i = dnnl_gOdhwI64o4i,
1978  gdhwio = dnnl_gdhwio,
1979  gdhwIo2i = dnnl_gdhwIo2i,
1980  gdhwIo4i = dnnl_gdhwIo4i,
1981  };
1982 
1984  struct desc {
1985  friend struct memory;
1988 
/// Constructs a zero memory descriptor: the underlying C structure `data`
/// is value-initialized, so `ndims` is 0 and is_zero() reports true.
1991  desc() : data() {}
1992 
2008  desc(const dims &adims, data_type adata_type, format_tag aformat_tag,
2009  bool allow_empty = false)
2010  : data() {
2011  validate_dims(adims);
2013  (int)adims.size(), adims.data(), convert_to_c(adata_type),
2014  convert_to_c(aformat_tag));
2015  if (!allow_empty)
2017  "could not construct a memory descriptor using a "
2018  "format tag");
2019  }
2020 
2036  desc(const dims &adims, data_type adata_type, const dims &strides,
2037  bool allow_empty = false)
2038  : data() {
2039  validate_dims(adims);
2040  if (!strides.empty()) validate_dims(strides, (int)adims.size());
2042  (int)adims.size(), adims.data(), convert_to_c(adata_type),
2043  strides.empty() ? nullptr : &strides[0]);
2044  if (!allow_empty)
2046  "could not construct a memory descriptor using "
2047  "strides");
2048  }
2049 
2053  desc(const dnnl_memory_desc_t &data) : data(data) {}
2054 
2057  //
2066  desc submemory_desc(const dims &adims, const dims &offsets,
2067  bool allow_empty = false) const {
2068  validate_dims(adims, data.ndims);
2069  validate_dims(offsets, data.ndims);
2072  &sub_md, &data, adims.data(), offsets.data());
2073  if (!allow_empty)
2074  error::wrap_c_api(status, "could not construct a sub-memory");
2075  return desc(sub_md);
2076  }
2077 
2122  desc reshape(const dims &adims, bool allow_empty = false) const {
2123  if (data.ndims) validate_dims(adims, 1);
2126  &out_md, &data, (int)adims.size(), adims.data());
2127  if (!allow_empty)
2129  status, "could not reshape a memory descriptor");
2130  return desc(out_md);
2131  }
2132 
2170  desc permute_axes(const std::vector<int> &permutation,
2171  bool allow_empty = false) const {
2172  validate_dims(permutation, data.ndims);
2175  &out_md, &data, permutation.data());
2176  if (!allow_empty)
2178  "could not permute axes of a memory descriptor");
2179  return desc(out_md);
2180  }
2181 
2186  memory::dims dims() const {
2187  return memory::dims(data.dims, data.dims + data.ndims);
2188  }
2189 
2193  return static_cast<memory::data_type>(data.data_type);
2194  }
2195 
2200  size_t get_size() const { return dnnl_memory_desc_get_size(&data); }
2201 
2205  bool is_zero() const { return data.ndims == 0; }
2206 
2211  bool operator==(const desc &other) const {
2212  return dnnl_memory_desc_equal(&data, &other.data) != 0;
2213  }
2214 
2219  bool operator!=(const desc &other) const { return !operator==(other); }
2220 
2224  explicit operator bool() const { return data.ndims != 0; }
2225  };
2226 
/// Default constructor.
/// NOTE(review): presumably leaves the object without an underlying C
/// handle (the default state of its handle<> base) — confirm.
2231  memory() = default;
2232 
2252  memory(const desc &md, const engine &aengine, void *handle) {
2253  dnnl_memory_t result;
2255  dnnl_memory_create(&result, &md.data, aengine.get(), handle),
2256  "could not create a memory object");
2257  reset(result);
2258  }
2259 
2266  memory(const desc &md, const engine &aengine)
2267  : memory(md, aengine, DNNL_MEMORY_ALLOCATE) {}
2268 
2270  desc get_desc() const {
2271  const dnnl_memory_desc_t *cdesc;
2273  "could not get a memory descriptor from a memory object");
2274  return desc(*cdesc);
2275  }
2276 
2278  engine get_engine() const {
2279  dnnl_engine_t c_engine;
2280  error::wrap_c_api(dnnl_memory_get_engine(get(), &c_engine),
2281  "could not get an engine from a memory object");
2282  return engine(c_engine, true);
2283  }
2284 
2289  void *get_data_handle() const {
2290  void *handle;
2292  "could not get a native handle from a memory object");
2293  return handle;
2294  }
2295 
2324  void set_data_handle(void *handle, const stream &astream) const {
2326  get(), handle, astream.get(true)),
2327  "could not set native handle of a memory object");
2328  }
2329 
2340  void set_data_handle(void *handle) const {
2342  dnnl_memory_set_data_handle_v2(get(), handle, nullptr),
2343  "could not set native handle of a memory object");
2344  }
2345 
2367  template <typename T = void>
2368  T *map_data() const {
2369  void *mapped_ptr;
2370  error::wrap_c_api(dnnl_memory_map_data(get(), &mapped_ptr),
2371  "could not map memory object data");
2372  return static_cast<T *>(mapped_ptr);
2373  }
2374 
2385  void unmap_data(void *mapped_ptr) const {
2386  error::wrap_c_api(dnnl_memory_unmap_data(get(), mapped_ptr),
2387  "could not unmap memory object data");
2388  }
2389 
2390  static dnnl_data_type_t convert_to_c(data_type adata_type) {
2391  return static_cast<dnnl_data_type_t>(adata_type);
2392  }
2393  static dnnl_format_tag_t convert_to_c(format_tag format) {
2394  return static_cast<dnnl_format_tag_t>(format);
2395  }
2396 };
2397 
2398 inline bool operator==(dnnl_data_type_t a, memory::data_type b) {
2399  return a == memory::convert_to_c(b);
2400 }
2401 inline bool operator!=(dnnl_data_type_t a, memory::data_type b) {
2402  return !(a == b);
2403 }
2404 inline bool operator==(memory::data_type a, dnnl_data_type_t b) {
2405  return b == a;
2406 }
2407 inline bool operator!=(memory::data_type a, dnnl_data_type_t b) {
2408  return !(a == b);
2409 }
2410 
2411 inline bool operator==(dnnl_format_tag_t a, memory::format_tag b) {
2412  return a == memory::convert_to_c(b);
2413 }
2414 inline bool operator!=(dnnl_format_tag_t a, memory::format_tag b) {
2415  return !(a == b);
2416 }
2417 inline bool operator==(memory::format_tag a, dnnl_format_tag_t b) {
2418  return b == a;
2419 }
2420 inline bool operator!=(memory::format_tag a, dnnl_format_tag_t b) {
2421  return !(a == b);
2422 }
2423 
2425 
2433 
2435 template <>
2436 struct handle_traits<dnnl_post_ops_t> {
2437  static dnnl_status_t destructor(dnnl_post_ops_t p) {
2438  return dnnl_post_ops_destroy(p);
2439  }
2440 };
2442 
2450 struct post_ops : public handle<dnnl_post_ops_t> {
2452 
2455  dnnl_post_ops_t result;
2457  dnnl_post_ops_create(&result), "could not create post-ops");
2458  reset(result);
2459  }
2460 
2462  int len() const { return dnnl_post_ops_len(get()); }
2463 
2467  primitive::kind kind(int index) const {
2469  "post-ops index is out of range");
2470  return static_cast<primitive::kind>(
2471  dnnl_post_ops_get_kind(get(), index));
2472  }
2473 
2502  void append_sum(float scale = 1.f,
2504  if (data_type == memory::data_type::undef)
2506  "could not append a sum post-op");
2507  else
2509  memory::convert_to_c(data_type)),
2510  "could not append a sum post-op");
2511  }
2512 
2517  void get_params_sum(int index, float &scale) const {
2519  "could not get parameters of a sum post-op");
2520  }
2521 
2528  int index, float &scale, memory::data_type &data_type) const {
2529  dnnl_data_type_t c_data_type;
2531  get(), index, &scale, &c_data_type),
2532  "could not get parameters of a sum post-op");
2533  data_type = static_cast<memory::data_type>(c_data_type);
2534  }
2535 
2550  float scale, algorithm aalgorithm, float alpha, float beta) {
2552  convert_to_c(aalgorithm), alpha, beta),
2553  "could not append an elementwise post-op");
2554  }
2555 
2563  void get_params_eltwise(int index, float &scale, algorithm &aalgorithm,
2564  float &alpha, float &beta) const {
2565  dnnl_alg_kind_t c_alg;
2567  get(), index, &scale, &c_alg, &alpha, &beta),
2568  "could not get parameters of an elementwise post-op");
2569  aalgorithm = static_cast<dnnl::algorithm>(c_alg);
2570  }
2571 
2600  void append_dw_k3s1p1(memory::data_type weights_data_type,
2601  memory::data_type bias_data_type, memory::data_type dst_data_type,
2602  int mask, const std::vector<float> &scales) {
2603 
2605  memory::convert_to_c(weights_data_type),
2606  memory::convert_to_c(bias_data_type),
2607  memory::convert_to_c(dst_data_type),
2608  scales.size(), mask, &scales[0]),
2609  "could not append depthwise post-op");
2610  }
2611 
2626  void get_params_dw_k3s1p1(int index, memory::data_type &weights_data_type,
2627  memory::data_type &bias_data_type, memory::data_type &dst_data_type,
2628  int &mask, std::vector<float> &scales) const {
2629 
2630  dnnl_data_type_t c_weights_data_type;
2631  dnnl_data_type_t c_bias_data_type;
2632  dnnl_data_type_t c_dst_data_type;
2633  dnnl_dim_t count;
2634  int c_mask;
2635  const float *c_scales;
2637  &c_weights_data_type, &c_bias_data_type,
2638  &c_dst_data_type, &count, &c_mask, &c_scales),
2639  "could not get parameters of depthwise post-op");
2640 
2641  weights_data_type = static_cast<memory::data_type>(c_weights_data_type);
2642  bias_data_type = static_cast<memory::data_type>(c_bias_data_type);
2643  dst_data_type = static_cast<memory::data_type>(c_dst_data_type);
2644  scales.resize(count);
2645 
2646  mask = c_mask;
2647  for (dnnl_dim_t c = 0; c < count; ++c)
2648  scales[c] = c_scales[c];
2649  return;
2650  }
2651 
2685  void append_dw_k3s2p1(memory::data_type weights_data_type,
2686  memory::data_type bias_data_type, memory::data_type dst_data_type,
2687  int mask, const std::vector<float> &scales) {
2688 
2690  memory::convert_to_c(weights_data_type),
2691  memory::convert_to_c(bias_data_type),
2692  memory::convert_to_c(dst_data_type),
2693  scales.size(), mask, &scales[0]),
2694  "could not append depthwise post-op");
2695  }
2696 
2711  void get_params_dw_k3s2p1(int index, memory::data_type &weights_data_type,
2712  memory::data_type &bias_data_type, memory::data_type &dst_data_type,
2713  int &mask, std::vector<float> &scales) const {
2714 
2715  dnnl_data_type_t c_weights_data_type;
2716  dnnl_data_type_t c_bias_data_type;
2717  dnnl_data_type_t c_dst_data_type;
2718  dnnl_dim_t count;
2719  int c_mask;
2720  const float *c_scales;
2722  &c_weights_data_type, &c_bias_data_type,
2723  &c_dst_data_type, &count, &c_mask, &c_scales),
2724  "could not get parameters of depthwise post-op");
2725 
2726  weights_data_type = static_cast<memory::data_type>(c_weights_data_type);
2727  bias_data_type = static_cast<memory::data_type>(c_bias_data_type);
2728  dst_data_type = static_cast<memory::data_type>(c_dst_data_type);
2729  scales.resize(count);
2730 
2731  mask = c_mask;
2732  for (dnnl_dim_t c = 0; c < count; ++c)
2733  scales[c] = c_scales[c];
2734  return;
2735  }
2736 
2751  void append_binary(algorithm aalgorithm, const memory::desc &src1_desc) {
2753  convert_to_c(aalgorithm), &src1_desc.data),
2754  "could not append a binary post-op");
2755  }
2756 
2763  int index, algorithm &aalgorithm, memory::desc &src1_desc) const {
2764  dnnl_alg_kind_t c_alg;
2765  const dnnl_memory_desc_t *data;
2767  dnnl_post_ops_get_params_binary(get(), index, &c_alg, &data),
2768  "could not get parameters of a binary post-op");
2769  aalgorithm = static_cast<dnnl::algorithm>(c_alg);
2770  src1_desc.data = *data;
2771  }
2772 };
2773 
2775 template <>
2776 struct handle_traits<dnnl_primitive_attr_t> {
2777  static dnnl_status_t destructor(dnnl_primitive_attr_t p) {
2778  return dnnl_primitive_attr_destroy(p);
2779  }
2780 };
2782 
2786 struct primitive_attr : public handle<dnnl_primitive_attr_t> {
2788 
2791  dnnl_primitive_attr_t result;
2793  "could not create primitive attribute");
2794  reset(result);
2795  }
2796 
2803  : handle<dnnl_primitive_attr_t>(attr) {}
2804 
2807  dnnl_scratchpad_mode_t result;
2810  "could not get scratchpad mode primitive attribute");
2811  return scratchpad_mode(result);
2812  }
2813 
2819  get(), dnnl::convert_to_c(mode)),
2820  "could not set scratchpad mode primitive attribute");
2821  }
2822 
2832  void get_output_scales(int &mask, std::vector<float> &scales) const {
2833  dnnl_dim_t count;
2834  int c_mask;
2835  const float *c_scales;
2837  get(), &count, &c_mask, &c_scales),
2838  "could not get output scales primitive attribute");
2839  scales.resize(count);
2840 
2841  mask = c_mask;
2842  for (dnnl_dim_t c = 0; c < count; ++c)
2843  scales[c] = c_scales[c];
2844  }
2845 
2888  void set_output_scales(int mask, const std::vector<float> &scales) {
2891  get(), (dnnl_dim_t)scales.size(), mask, scales.data()),
2892  "could not set output scales primitive attribute");
2893  }
2894 
2906  void get_scales(int arg, int &mask, std::vector<float> &scales) const {
2907  dnnl_dim_t count;
2908  int c_mask;
2909  const float *c_scales;
2911  get(), arg, &count, &c_mask, &c_scales),
2912  "could not get scales primitive attributes");
2913  scales.resize(count);
2914 
2915  mask = c_mask;
2916  for (dnnl_dim_t c = 0; c < count; ++c)
2917  scales[c] = c_scales[c];
2918  }
2919 
2936  void set_scales(int arg, int mask, const std::vector<float> &scales) {
2939  (dnnl_dim_t)scales.size(), mask, scales.data()),
2940  "could not set scales primitive attribute");
2941  }
2942 
2954  int arg, int &mask, std::vector<int32_t> &zero_points) const {
2955  dnnl_dim_t count;
2956  int c_mask;
2957  const int32_t *c_zero_points;
2959  get(), arg, &count, &c_mask, &c_zero_points),
2960  "could not get zero points primitive attribute");
2961  zero_points.resize(count);
2962 
2963  mask = c_mask;
2964  for (dnnl_dim_t c = 0; c < count; ++c)
2965  zero_points[c] = c_zero_points[c];
2966  }
2967 
2989  int arg, int mask, const std::vector<int32_t> &zero_points) {
2991  (dnnl_dim_t)zero_points.size(), mask,
2992  zero_points.data()),
2993  "could not set zero points primitive attribute");
2994  }
2995 
2999  const post_ops get_post_ops() const {
3000  post_ops result;
3001  const_dnnl_post_ops_t c_result;
3003  "could not get post-ops primitive attribute");
3004  result.reset(const_cast<dnnl_post_ops_t>(c_result), true);
3005  return result;
3006  }
3007 
3016  void set_post_ops(const post_ops ops) {
3018  "could not set post-ops primitive attribute");
3019  }
3020 
3054  void set_rnn_data_qparams(float scale, float shift) {
3057  "could not set RNN data quantization parameters primitive "
3058  "attribute");
3059  }
3060 
3070  void get_rnn_data_qparams(float &scale, float &shift) {
3071  float c_scale, c_shift;
3073  get(), &c_scale, &c_shift),
3074  "could not set RNN data quantization parameters primitive "
3075  "attribute");
3076  scale = c_scale;
3077  shift = c_shift;
3078  }
3079 
3106  void set_rnn_weights_qparams(int mask, const std::vector<float> &scales) {
3108  (int)scales.size(), mask, scales.data()),
3109  "could not set RNN weights quantization parameters primitive "
3110  "attribute");
3111  }
3112 
3132  void get_rnn_weights_qparams(int &mask, std::vector<float> &scales) {
3133  dnnl_dim_t count;
3134  int c_mask;
3135  const float *c_scales;
3137  get(), &count, &c_mask, &c_scales),
3138  "could not get primitive RNN weights quantization "
3139  "parameters attributes");
3140  scales.resize(count);
3141 
3142  mask = c_mask;
3143  for (dnnl_dim_t c = 0; c < count; c++)
3144  scales[c] = c_scales[c];
3145  }
3146 
3148  // The low-precision configuration of the RNN primitives expects input
3149  // weights to use the signed 8-bit integer data type. The scaling factors
3150  // are used to quantize floating-point data to signed integer and must be
3174  int mask, const std::vector<float> &scales) {
3177  get(), (int)scales.size(), mask, scales.data()),
3178  "could not set primitive RNN weights projection quantization "
3179  "parameters attributes");
3180  }
3181 
3202  int &mask, std::vector<float> &scales) {
3203  dnnl_dim_t count;
3204  int c_mask;
3205  const float *c_scales;
3208  get(), &count, &c_mask, &c_scales),
3209  "could not get primitive RNN weights projection quantization "
3210  "parameters attributes");
3211  scales.resize(count);
3212 
3213  mask = c_mask;
3214  for (dnnl_dim_t c = 0; c < count; c++)
3215  scales[c] = c_scales[c];
3216  }
3217 };
3218 
3220 
3223 
3225 struct primitive_desc_base : public handle<dnnl_primitive_desc_t> {
3227 
3229  primitive_desc_base() = default;
3230 
3233  engine get_engine() const { return engine::query(*this); }
3234 
3237  const char *impl_info_str() const {
3238  const char *res;
3240  get(), dnnl_query_impl_info_str, 0, &res),
3241  "could not retrieve implementation info string from a "
3242  "primitive descriptor");
3243  return res;
3244  }
3245 
3250  memory::dim res;
3252  get(), dnnl::convert_to_c(what), 0, &res);
3253  return status == dnnl_success ? res : 0;
3254  }
3255 
3270  memory::desc query_md(query what, int idx = 0) const {
3271  std::vector<query> valid_q {query::src_md, query::diff_src_md,
3275  if (!std::any_of(valid_q.cbegin(), valid_q.cend(),
3276  [=](query q) { return what == q; }))
3277  DNNL_THROW_ERROR(dnnl_invalid_arguments,
3278  "memory descriptor query is invalid");
3279 
3281  get(), dnnl::convert_to_c(what), idx);
3282  return cdesc ? memory::desc(*cdesc) : memory::desc();
3283  }
3284 
3290  memory::desc src_desc(int idx) const {
3291  return query_md(query::src_md, idx);
3292  }
3293 
3299  memory::desc dst_desc(int idx) const {
3300  return query_md(query::dst_md, idx);
3301  }
3302 
3308  memory::desc weights_desc(int idx) const {
3309  return query_md(query::weights_md, idx);
3310  }
3311 
3317  memory::desc diff_src_desc(int idx) const {
3318  return query_md(query::diff_src_md, idx);
3319  }
3320 
3326  memory::desc diff_dst_desc(int idx) const {
3327  return query_md(query::diff_dst_md, idx);
3328  }
3329 
3336  return query_md(query::diff_weights_md, idx);
3337  }
3338 
3339  // Separate versions without the index argument for documentation
3340  // purposes.
3341 
3346  memory::desc src_desc() const { return src_desc(0); }
3347 
3352  memory::desc dst_desc() const { return dst_desc(0); }
3353 
3358  memory::desc weights_desc() const { return weights_desc(0); }
3359 
3365 
3371 
3377 
3383  return query_md(query::workspace_md, 0);
3384  }
3385 
3392  return query_md(query::scratchpad_md, 0);
3393  }
3394 
3398  dnnl_engine_t c_engine;
3401  0, &c_engine),
3402  "could not retrieve scratchpad engine from a primitive "
3403  "descriptor");
3404  return engine(c_engine, true);
3405  }
3406 
3410  const_dnnl_primitive_attr_t const_c_attr;
3412  "could not get attributes from a primitive descriptor");
3413  dnnl_primitive_attr_t c_attr;
3414  error::wrap_c_api(dnnl_primitive_attr_clone(&c_attr, const_c_attr),
3415  "could not clone primitive attributes");
3416  return primitive_attr(c_attr);
3417  }
3418 
3422  dnnl_primitive_kind_t kind;
3424  dnnl_query_primitive_kind, 0, (void *)&kind),
3425  "could not get primitive kind from a primitive descriptor");
3426  return static_cast<dnnl::primitive::kind>(kind);
3427  }
3428 
3429 protected:
3434  dnnl_primitive_desc_t new_pd;
3436  "could not clone a primitive descriptor");
3437  reset(new_pd);
3438  }
3439 
3455  : primitive_desc_base(pd, prim_kind, dnnl::prop_kind::undef) {}
3456 
3469  dnnl::primitive::kind prim_kind, dnnl::prop_kind aprop_kind)
3470  : primitive_desc_base(pd, prim_kind, aprop_kind, aprop_kind) {}
3471 
3486  dnnl::primitive::kind prim_kind, dnnl::prop_kind prop_kind1,
3487  dnnl::prop_kind prop_kind2) {
3488  // It is OK to pass an empty primitive descriptor
3489  if (pd == nullptr) return;
3490 
3491  dnnl_status_t rc;
3492 
3493  dnnl_primitive_kind_t c_prim_kind = convert_to_c(prim_kind);
3494  dnnl_prop_kind_t c_prop_kind1 = convert_to_c(prop_kind1);
3495  dnnl_prop_kind_t c_prop_kind2 = convert_to_c(prop_kind2);
3496 
3497  // Check that primitive kind matches
3498  dnnl_primitive_kind_t pd_kind;
3500  pd, dnnl_query_primitive_kind, 0, (void *)&pd_kind);
3502  rc, "could not get primitive kind from a primitive descriptor");
3503  if (pd_kind != c_prim_kind)
3504  DNNL_THROW_ERROR(dnnl_invalid_arguments,
3505  "primitive descriptor operation kind mismatch");
3506 
3507  // Check that propagation kind matches
3508  dnnl_prop_kind_t pd_prop_kind;
3510  pd, dnnl_query_prop_kind, 0, (void *)&pd_prop_kind);
3511 
3512  // Something went wrong
3513  if (rc != dnnl_success && rc != dnnl_unimplemented)
3514  DNNL_THROW_ERROR(dnnl_invalid_arguments,
3515  "could not get propagation kind from the primitive "
3516  "descriptor");
3517 
3518  // Everything is fine
3519  if ((rc == dnnl_unimplemented && c_prop_kind1 == dnnl_prop_kind_undef)
3520  || (rc == dnnl_success
3521  && (pd_prop_kind == c_prop_kind1
3522  || pd_prop_kind == c_prop_kind2))) {
3523  reset_with_clone(pd);
3524  return;
3525  }
3526 
3527  // We could get the propagation kind but there is a mismatch
3528  DNNL_THROW_ERROR(dnnl_invalid_arguments,
3529  "primitive descriptor propagation kind mismatch");
3530  }
3531 
3532  using base = primitive_desc_base;
3533 };
3534 
3536 
3545 
3547 struct reorder : public primitive {
3551 
/// Default constructor for the reorder primitive descriptor.
/// NOTE(review): presumably yields an empty descriptor with no underlying C
/// handle — confirm against handle<>'s default state.
3553  primitive_desc() = default;
3554 
3572  primitive_desc(const engine &src_engine, const memory::desc &src_md,
3573  const engine &dst_engine, const memory::desc &dst_md,
3574  const primitive_attr &attr = primitive_attr(),
3575  bool allow_empty = false) {
3576  dnnl_primitive_desc_t result;
3578  &src_md.data, src_engine.get(), &dst_md.data,
3579  dst_engine.get(), attr.get());
3580  if (!allow_empty)
3582  "could not create a primitive descriptor for a reorder "
3583  "primitive");
3585  }
3586 
3598  primitive_desc(const memory &src, const memory &dst,
3599  const primitive_attr &attr = primitive_attr(),
3600  bool allow_empty = false) {
3601  dnnl_primitive_desc_t result;
3602  auto src_md = src.get_desc();
3603  auto dst_md = dst.get_desc();
3605  &src_md.data, src.get_engine().get(), &dst_md.data,
3606  dst.get_engine().get(), attr.get());
3607  if (!allow_empty)
3609  "could not create a primitive descriptor for a reorder "
3610  "primitive");
3612  }
3613 
3620 
3625  }
3626 
3631  }
3632 
3634  memory::desc src_desc() const { return base::src_desc(0); }
3635 
3637  memory::desc dst_desc() const { return base::dst_desc(0); }
3638  };
3639 
3641  reorder() = default;
3642 
3645  reorder(const primitive_desc &pd) : primitive(pd.get()) {}
3646 
3654  reorder(const memory &src, const memory &dst,
3655  const primitive_attr &attr = primitive_attr())
3656  : primitive(primitive_desc(src, dst, attr).get()) {}
3657 
3658  using primitive::execute;
3659 
3666  void execute(const stream &astream, memory &src, memory &dst) const {
3667  primitive::execute(astream, {{DNNL_ARG_FROM, src}, {DNNL_ARG_TO, dst}});
3668  }
3669 };
3670 
3672 
3680 
3682 inline std::vector<dnnl_memory_desc_t> convert_to_c(
3683  const std::vector<memory::desc> &mems) {
3684  std::vector<dnnl_memory_desc_t> c_mems;
3685  c_mems.reserve(mems.size());
3686  for (const auto &s : mems)
3687  c_mems.push_back(s.data);
3688  return c_mems;
3689 }
3691 
3693 struct concat : public primitive {
3697 
/// Default constructor for the concat primitive descriptor.
/// NOTE(review): presumably yields an empty descriptor with no underlying C
/// handle — confirm against handle<>'s default state.
3699  primitive_desc() = default;
3700 
3711  primitive_desc(const memory::desc &dst, int concat_dimension,
3712  const std::vector<memory::desc> &srcs, const engine &aengine,
3713  const primitive_attr &attr = primitive_attr()) {
3714  auto c_srcs = convert_to_c(srcs);
3715 
3716  dnnl_primitive_desc_t result;
3719  (int)c_srcs.size(), concat_dimension, c_srcs.data(),
3720  attr.get(), aengine.get()),
3721  "could not create a primitive descriptor for a concat "
3722  "primitive");
3723  reset(result);
3724  }
3725 
3738  primitive_desc(int concat_dimension,
3739  const std::vector<memory::desc> &srcs, const engine &aengine,
3740  const primitive_attr &attr = primitive_attr()) {
3741  auto c_api_srcs = convert_to_c(srcs);
3742 
3743  dnnl_primitive_desc_t result;
3745  dnnl_concat_primitive_desc_create(&result, nullptr,
3746  (int)c_api_srcs.size(), concat_dimension,
3747  c_api_srcs.data(), attr.get(), aengine.get()),
3748  "could not create a primitive descriptor for a concat "
3749  "primitive");
3750  reset(result);
3751  }
3752 
3759 
3761  memory::desc src_desc(int idx = 0) const { return base::src_desc(idx); }
3762 
3764  memory::desc dst_desc() const { return base::dst_desc(0); }
3765  };
3766 
3768  concat() = default;
3769 
3772  concat(const primitive_desc &pd) : primitive(pd.get()) {}
3773 };
3774 
3776 
3784 
3786 struct sum : public primitive {
3790 
/// Default constructor for the sum primitive descriptor.
/// NOTE(review): presumably yields an empty descriptor with no underlying C
/// handle — confirm against handle<>'s default state.
3792  primitive_desc() = default;
3793 
3803  const std::vector<float> &scales,
3804  const std::vector<memory::desc> &srcs, const engine &aengine,
3805  const primitive_attr &attr = primitive_attr()) {
3806  validate_container_size(scales,
3807  "counts of scales and sources are not equal",
3808  (int)srcs.size(), (int)srcs.size());
3809 
3810  auto c_api_srcs = convert_to_c(srcs);
3811 
3812  dnnl_primitive_desc_t result;
3814  dnnl_sum_primitive_desc_create(&result, &dst.data,
3815  (int)c_api_srcs.size(), scales.data(),
3816  c_api_srcs.data(), attr.get(), aengine.get()),
3817  "could not create a primitive descriptor for a sum "
3818  "primitive");
3819  reset(result);
3820  }
3821 
3832  primitive_desc(const std::vector<float> &scales,
3833  const std::vector<memory::desc> &srcs, const engine &aengine,
3834  const primitive_attr &attr = primitive_attr()) {
3835  validate_container_size(scales,
3836  "counts of scales and sources are not equal",
3837  (int)srcs.size(), (int)srcs.size());
3838 
3839  auto c_api_srcs = convert_to_c(srcs);
3840  dnnl_primitive_desc_t result;
3842  dnnl_sum_primitive_desc_create(&result, nullptr,
3843  (int)c_api_srcs.size(), scales.data(),
3844  c_api_srcs.data(), attr.get(), aengine.get()),
3845  "could not create a primitive descriptor for a sum "
3846  "primitive");
3847  reset(result);
3848  }
3849 
3856 
3858  memory::desc src_desc(int idx = 0) const { return base::src_desc(idx); }
3859 
3861  memory::desc dst_desc() const { return base::dst_desc(0); }
3862  };
3863 
3865  sum() = default;
3866 
3869  sum(const primitive_desc &pd) : primitive(pd.get()) {}
3870 };
3871 
3873 
3876 
3881 
/// Default constructor.
/// NOTE(review): presumably yields an empty descriptor with no underlying C
/// handle and `allow_empty_` at its in-class default (false) — confirm.
3882  primitive_desc() = default;
3883 
3907  const engine &aengine, const_dnnl_primitive_desc_t hint_fwd_pd,
3908  bool allow_empty = false)
3909  : allow_empty_(allow_empty) {
3910  dnnl_primitive_desc_iterator_t iterator = nullptr;
3912  desc, attr ? attr->get() : nullptr, aengine.get(), hint_fwd_pd);
3913  if (!allow_empty)
3915  status, "could not create a primitive descriptor iterator");
3916  pd_iterator.reset(iterator);
3917  fetch_impl();
3918  }
3919 
3924  bool next_impl() {
3926  = dnnl_primitive_desc_iterator_next(pd_iterator.get());
3927  if (status == dnnl_iterator_ends) return false;
3929  status, "could not advance a primitive descriptor iterator");
3930  fetch_impl();
3931  return true;
3932  }
3933 
3934 private:
3935  bool allow_empty_ = false;
3937  void fetch_impl() {
3939  pd_iterator.get(allow_empty_));
3940  error::wrap_c_api(pd != nullptr || allow_empty_ ? dnnl_success
3942  "could not fetch a primitive descriptor from a primitive "
3943  "descriptor iterator");
3944  reset(pd);
3945  }
3946 };
3947 
3949 
3959 
3963  struct desc {
3965 
3996  desc(prop_kind aprop_kind, algorithm aalgorithm,
3997  const memory::desc &src_desc, const memory::desc &weights_desc,
3998  const memory::desc &bias_desc, const memory::desc &dst_desc,
3999  const memory::dims &strides, const memory::dims &padding_l,
4000  const memory::dims &padding_r) {
4001  memory::validate_dims(strides, src_desc.data.ndims - 2);
4002  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4003  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4006  dnnl::convert_to_c(aprop_kind),
4007  convert_to_c(aalgorithm), &src_desc.data,
4008  &weights_desc.data, &bias_desc.data, &dst_desc.data,
4009  &strides[0], &padding_l[0], &padding_r[0]),
4010  "could not create a descriptor for a convolution forward "
4011  "propagation primitive");
4012  }
4013 
4042  desc(prop_kind aprop_kind, algorithm aalgorithm,
4043  const memory::desc &src_desc, const memory::desc &weights_desc,
4044  const memory::desc &dst_desc, const memory::dims &strides,
4045  const memory::dims &padding_l, const memory::dims &padding_r) {
4046  memory::validate_dims(strides, src_desc.data.ndims - 2);
4047  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4048  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4051  dnnl::convert_to_c(aprop_kind),
4052  convert_to_c(aalgorithm), &src_desc.data,
4053  &weights_desc.data, nullptr, &dst_desc.data,
4054  &strides[0], &padding_l[0], &padding_r[0]),
4055  "could not create a descriptor for a convolution forward "
4056  "propagation primitive");
4057  }
4058 
4091  desc(prop_kind aprop_kind, algorithm aalgorithm,
4092  const memory::desc &src_desc, const memory::desc &weights_desc,
4093  const memory::desc &bias_desc, const memory::desc &dst_desc,
4094  const memory::dims &strides, const memory::dims &dilates,
4095  const memory::dims &padding_l, const memory::dims &padding_r) {
4096  memory::validate_dims(strides, src_desc.data.ndims - 2);
4097  memory::validate_dims(dilates, src_desc.data.ndims - 2);
4098  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4099  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4101  dnnl::convert_to_c(aprop_kind),
4102  convert_to_c(aalgorithm), &src_desc.data,
4103  &weights_desc.data, &bias_desc.data,
4104  &dst_desc.data, &strides[0], &dilates[0],
4105  &padding_l[0], &padding_r[0]),
4106  "could not create a descriptor for a dilated convolution "
4107  "forward propagation primitive");
4108  }
4109 
4140  desc(prop_kind aprop_kind, algorithm aalgorithm,
4141  const memory::desc &src_desc, const memory::desc &weights_desc,
4142  const memory::desc &dst_desc, const memory::dims &strides,
4143  const memory::dims &dilates, const memory::dims &padding_l,
4144  const memory::dims &padding_r) {
4145  memory::validate_dims(strides, src_desc.data.ndims - 2);
4146  memory::validate_dims(dilates, src_desc.data.ndims - 2);
4147  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4148  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4150  dnnl::convert_to_c(aprop_kind),
4151  convert_to_c(aalgorithm), &src_desc.data,
4152  &weights_desc.data, nullptr,
4153  &dst_desc.data, &strides[0], &dilates[0],
4154  &padding_l[0], &padding_r[0]),
4155  "could not create a descriptor for a dilated convolution "
4156  "forward propagation primitive");
4157  }
4158  };
4159 
4163  primitive_desc() = default;
4164 
4175  primitive_desc(const desc &adesc, const engine &aengine,
4176  bool allow_empty = false)
4177  : dnnl::primitive_desc(
4178  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
4179 
4191  primitive_desc(const desc &adesc, const primitive_attr &attr,
4192  const engine &aengine, bool allow_empty = false)
4193  : dnnl::primitive_desc(
4194  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
4195 
4203  : dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution,
4206 
4208  memory::desc src_desc() const { return base::src_desc(0); }
4209 
4212 
4214  memory::desc dst_desc() const { return base::dst_desc(0); }
4215 
4221  };
4222 
/// Default constructor.
/// NOTE(review): presumably yields an empty primitive with no underlying C
/// handle — confirm against the primitive base's default state.
4224  convolution_forward() = default;
4225 
4230 };
4231 
4234 
4236  struct desc {
4238 
4264  desc(algorithm aalgorithm, const memory::desc &diff_src_desc,
4265  const memory::desc &weights_desc,
4266  const memory::desc &diff_dst_desc, const memory::dims &strides,
4267  const memory::dims &padding_l, const memory::dims &padding_r) {
4268  memory::validate_dims(strides, diff_src_desc.data.ndims - 2);
4269  memory::validate_dims(padding_l, diff_src_desc.data.ndims - 2);
4270  memory::validate_dims(padding_r, diff_src_desc.data.ndims - 2);
4273  convert_to_c(aalgorithm), &diff_src_desc.data,
4274  &weights_desc.data, &diff_dst_desc.data,
4275  &strides[0], &padding_l[0], &padding_r[0]),
4276  "could not create a descriptor for a convolution backward "
4277  "propagation primitive");
4278  }
4279 
4307  desc(algorithm aalgorithm, const memory::desc &diff_src_desc,
4308  const memory::desc &weights_desc,
4309  const memory::desc &diff_dst_desc, const memory::dims &strides,
4310  const memory::dims &dilates, const memory::dims &padding_l,
4311  const memory::dims &padding_r) {
4312  memory::validate_dims(strides, diff_src_desc.data.ndims - 2);
4313  memory::validate_dims(dilates, diff_src_desc.data.ndims - 2);
4314  memory::validate_dims(padding_l, diff_src_desc.data.ndims - 2);
4315  memory::validate_dims(padding_r, diff_src_desc.data.ndims - 2);
4318  convert_to_c(aalgorithm), &diff_src_desc.data,
4319  &weights_desc.data, &diff_dst_desc.data,
4320  &strides[0], &dilates[0], &padding_l[0],
4321  &padding_r[0]),
4322  "could not create a descriptor for a dilated convolution "
4323  "backward propagation primitive");
4324  }
4325  };
4326 
4330  primitive_desc() = default;
4331 
4345  primitive_desc(const desc &adesc, const engine &aengine,
4346  const convolution_forward::primitive_desc &hint_fwd_pd,
4347  bool allow_empty = false)
4348  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
4349  hint_fwd_pd.get(), allow_empty) {}
4350 
4365  primitive_desc(const desc &adesc, const primitive_attr &attr,
4366  const engine &aengine,
4367  const convolution_forward::primitive_desc &hint_fwd_pd,
4368  bool allow_empty = false)
4369  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
4370  hint_fwd_pd.get(), allow_empty) {}
4371 
4379  : dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution,
4381 
4384 
4387 
4390  };
4391 
4394 
4399 };
4400 
4404  struct desc {
4406 
4434  desc(algorithm aalgorithm, const memory::desc &src_desc,
4435  const memory::desc &diff_weights_desc,
4436  const memory::desc &diff_bias_desc,
4437  const memory::desc &diff_dst_desc, const memory::dims &strides,
4438  const memory::dims &padding_l, const memory::dims &padding_r) {
4439  memory::validate_dims(strides, src_desc.data.ndims - 2);
4440  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4441  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4444  convert_to_c(aalgorithm), &src_desc.data,
4445  &diff_weights_desc.data, &diff_bias_desc.data,
4446  &diff_dst_desc.data, &strides[0], &padding_l[0],
4447  &padding_r[0]),
4448  "could not create a descriptor for a convolution weights "
4449  "update primitive");
4450  }
4451 
4477  desc(algorithm aalgorithm, const memory::desc &src_desc,
4478  const memory::desc &diff_weights_desc,
4479  const memory::desc &diff_dst_desc, const memory::dims &strides,
4480  const memory::dims &padding_l, const memory::dims &padding_r) {
4481  memory::validate_dims(strides, src_desc.data.ndims - 2);
4482  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4483  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4485  convert_to_c(aalgorithm), &src_desc.data,
4486  &diff_weights_desc.data, nullptr,
4487  &diff_dst_desc.data, &strides[0],
4488  &padding_l[0], &padding_r[0]),
4489  "could not create a descriptor for a convolution weights "
4490  "update primitive");
4491  }
4492 
4522  desc(algorithm aalgorithm, const memory::desc &src_desc,
4523  const memory::desc &diff_weights_desc,
4524  const memory::desc &diff_bias_desc,
4525  const memory::desc &diff_dst_desc, const memory::dims &strides,
4526  const memory::dims &dilates, const memory::dims &padding_l,
4527  const memory::dims &padding_r) {
4528  memory::validate_dims(strides, src_desc.data.ndims - 2);
4529  memory::validate_dims(dilates, src_desc.data.ndims - 2);
4530  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4531  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4534  convert_to_c(aalgorithm), &src_desc.data,
4535  &diff_weights_desc.data, &diff_bias_desc.data,
4536  &diff_dst_desc.data, &strides[0], &dilates[0],
4537  &padding_l[0], &padding_r[0]),
4538  "could not create a descriptor for a dilated convolution "
4539  "weights gradient primitive");
4540  }
4541 
4569  desc(algorithm aalgorithm, const memory::desc &src_desc,
4570  const memory::desc &diff_weights_desc,
4571  const memory::desc &diff_dst_desc, const memory::dims &strides,
4572  const memory::dims &dilates, const memory::dims &padding_l,
4573  const memory::dims &padding_r) {
4574  memory::validate_dims(strides, src_desc.data.ndims - 2);
4575  memory::validate_dims(dilates, src_desc.data.ndims - 2);
4576  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4577  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4580  convert_to_c(aalgorithm), &src_desc.data,
4581  &diff_weights_desc.data, nullptr,
4582  &diff_dst_desc.data, &strides[0], &dilates[0],
4583  &padding_l[0], &padding_r[0]),
4584  "could not create a descriptor for a dilated convolution "
4585  "weights gradient primitive");
4586  }
4587  };
4588 
4592  primitive_desc() = default;
4593 
4606  primitive_desc(const desc &adesc, const engine &aengine,
4607  const convolution_forward::primitive_desc &hint_fwd_pd,
4608  bool allow_empty = false)
4609  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
4610  hint_fwd_pd.get(), allow_empty) {}
4611 
4625  primitive_desc(const desc &adesc, const primitive_attr &attr,
4626  const engine &aengine,
4627  const convolution_forward::primitive_desc &hint_fwd_pd,
4628  bool allow_empty = false)
4629  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
4630  hint_fwd_pd.get(), allow_empty) {}
4631 
4639  : dnnl::primitive_desc(pd, dnnl::primitive::kind::convolution,
4641 
4643  memory::desc src_desc() const { return base::src_desc(0); }
4644 
4647  return base::diff_weights_desc(0);
4648  }
4649 
4652 
4658  return base::diff_weights_desc(1);
4659  }
4660  };
4661 
4664 
4669 };
4670 
4672 //
4680 
4684  struct desc {
4686 
4716  desc(prop_kind aprop_kind, algorithm aalgorithm,
4717  const memory::desc &src_desc, const memory::desc &weights_desc,
4718  const memory::desc &bias_desc, const memory::desc &dst_desc,
4719  const memory::dims &strides, const memory::dims &padding_l,
4720  const memory::dims &padding_r) {
4721  memory::validate_dims(strides, src_desc.data.ndims - 2);
4722  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4723  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4726  dnnl::convert_to_c(aprop_kind),
4727  convert_to_c(aalgorithm), &src_desc.data,
4728  &weights_desc.data, &bias_desc.data, &dst_desc.data,
4729  &strides[0], &padding_l[0], &padding_r[0]),
4730  "could not create a descriptor for a deconvolution forward "
4731  "propagation primitive");
4732  }
4733 
4761  desc(prop_kind aprop_kind, algorithm aalgorithm,
4762  const memory::desc &src_desc, const memory::desc &weights_desc,
4763  const memory::desc &dst_desc, const memory::dims &strides,
4764  const memory::dims &padding_l, const memory::dims &padding_r) {
4765  memory::validate_dims(strides, src_desc.data.ndims - 2);
4766  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4767  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4770  dnnl::convert_to_c(aprop_kind),
4771  convert_to_c(aalgorithm), &src_desc.data,
4772  &weights_desc.data, nullptr, &dst_desc.data,
4773  &strides[0], &padding_l[0], &padding_r[0]),
4774  "could not create a descriptor for a deconvolution forward "
4775  "propagation primitive");
4776  }
4777 
4809  desc(prop_kind aprop_kind, algorithm aalgorithm,
4810  const memory::desc &src_desc, const memory::desc &weights_desc,
4811  const memory::desc &bias_desc, const memory::desc &dst_desc,
4812  const memory::dims &strides, const memory::dims &dilates,
4813  const memory::dims &padding_l, const memory::dims &padding_r) {
4814  memory::validate_dims(strides, src_desc.data.ndims - 2);
4815  memory::validate_dims(dilates, src_desc.data.ndims - 2);
4816  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4817  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4819  &data, dnnl::convert_to_c(aprop_kind),
4820  convert_to_c(aalgorithm), &src_desc.data,
4821  &weights_desc.data, &bias_desc.data,
4822  &dst_desc.data, &strides[0], &dilates[0],
4823  &padding_l[0], &padding_r[0]),
4824  "could not create a descriptor for a dilated deconvolution "
4825  "forward propagation primitive");
4826  }
4827 
4857  desc(prop_kind aprop_kind, algorithm aalgorithm,
4858  const memory::desc &src_desc, const memory::desc &weights_desc,
4859  const memory::desc &dst_desc, const memory::dims &strides,
4860  const memory::dims &dilates, const memory::dims &padding_l,
4861  const memory::dims &padding_r) {
4862  memory::validate_dims(strides, src_desc.data.ndims - 2);
4863  memory::validate_dims(dilates, src_desc.data.ndims - 2);
4864  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
4865  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
4867  &data, dnnl::convert_to_c(aprop_kind),
4868  convert_to_c(aalgorithm), &src_desc.data,
4869  &weights_desc.data, nullptr,
4870  &dst_desc.data, &strides[0], &dilates[0],
4871  &padding_l[0], &padding_r[0]),
4872  "could not create a descriptor for a dilated deconvolution "
4873  "forward propagation primitive");
4874  }
4875  };
4876 
4880  primitive_desc() = default;
4881 
4892  primitive_desc(const desc &adesc, const engine &aengine,
4893  bool allow_empty = false)
4894  : dnnl::primitive_desc(
4895  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
4896 
4908  primitive_desc(const desc &adesc, const primitive_attr &attr,
4909  const engine &aengine, bool allow_empty = false)
4910  : dnnl::primitive_desc(
4911  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
4912 
4920  : dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution,
4923 
4925  memory::desc src_desc() const { return base::src_desc(0); }
4926 
4929 
4931  memory::desc dst_desc() const { return base::dst_desc(0); }
4932 
4935  };
4936 
4939 
4944 };
4945 
4949  struct desc {
4951 
4976  desc(algorithm aalgorithm, const memory::desc &diff_src_desc,
4977  const memory::desc &weights_desc,
4978  const memory::desc &diff_dst_desc, const memory::dims &strides,
4979  const memory::dims &padding_l, const memory::dims &padding_r) {
4980  memory::validate_dims(strides, diff_src_desc.data.ndims - 2);
4981  memory::validate_dims(padding_l, diff_src_desc.data.ndims - 2);
4982  memory::validate_dims(padding_r, diff_src_desc.data.ndims - 2);
4985  convert_to_c(aalgorithm), &diff_src_desc.data,
4986  &weights_desc.data, &diff_dst_desc.data,
4987  &strides[0], &padding_l[0], &padding_r[0]),
4988  "could not create a descriptor for a deconvolution "
4989  "backward propagation primitive");
4990  }
4991 
5018  desc(algorithm aalgorithm, const memory::desc &diff_src_desc,
5019  const memory::desc &weights_desc,
5020  const memory::desc &diff_dst_desc, const memory::dims &strides,
5021  const memory::dims &dilates, const memory::dims &padding_l,
5022  const memory::dims &padding_r) {
5023  memory::validate_dims(strides, diff_src_desc.data.ndims - 2);
5024  memory::validate_dims(dilates, diff_src_desc.data.ndims - 2);
5025  memory::validate_dims(padding_l, diff_src_desc.data.ndims - 2);
5026  memory::validate_dims(padding_r, diff_src_desc.data.ndims - 2);
5029  convert_to_c(aalgorithm), &diff_src_desc.data,
5030  &weights_desc.data, &diff_dst_desc.data,
5031  &strides[0], &dilates[0], &padding_l[0],
5032  &padding_r[0]),
5033  "could not create a descriptor for a dilated deconvolution "
5034  "backward propagation primitive");
5035  }
5036  };
5037 
5041  primitive_desc() = default;
5042 
5056  primitive_desc(const desc &adesc, const engine &aengine,
5057  const deconvolution_forward::primitive_desc &hint_fwd_pd,
5058  bool allow_empty = false)
5059  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
5060  hint_fwd_pd.get(), allow_empty) {}
5061 
5076  primitive_desc(const desc &adesc, const primitive_attr &attr,
5077  const engine &aengine,
5078  const deconvolution_forward::primitive_desc &hint_fwd_pd,
5079  bool allow_empty = false)
5080  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
5081  hint_fwd_pd.get(), allow_empty) {}
5082 
5090  : dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution,
5092 
5095 
5098 
5101  };
5102 
5105 
5110 };
5111 
5115  struct desc {
5117 
5144  desc(algorithm aalgorithm, const memory::desc &src_desc,
5145  const memory::desc &diff_weights_desc,
5146  const memory::desc &diff_bias_desc,
5147  const memory::desc &diff_dst_desc, const memory::dims &strides,
5148  const memory::dims &padding_l, const memory::dims &padding_r) {
5149  memory::validate_dims(strides, src_desc.data.ndims - 2);
5150  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
5151  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
5154  convert_to_c(aalgorithm), &src_desc.data,
5155  &diff_weights_desc.data, &diff_bias_desc.data,
5156  &diff_dst_desc.data, &strides[0], &padding_l[0],
5157  &padding_r[0]),
5158  "could not create a descriptor for a deconvolution weights "
5159  "update primitive");
5160  }
5161 
5186  desc(algorithm aalgorithm, const memory::desc &src_desc,
5187  const memory::desc &diff_weights_desc,
5188  const memory::desc &diff_dst_desc, const memory::dims &strides,
5189  const memory::dims &padding_l, const memory::dims &padding_r) {
5190  memory::validate_dims(strides, src_desc.data.ndims - 2);
5191  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
5192  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
5194  &data, convert_to_c(aalgorithm),
5195  &src_desc.data, &diff_weights_desc.data,
5196  nullptr, &diff_dst_desc.data, &strides[0],
5197  &padding_l[0], &padding_r[0]),
5198  "could not create a descriptor for a deconvolution weights "
5199  "update primitive");
5200  }
5201 
5230  desc(algorithm aalgorithm, const memory::desc &src_desc,
5231  const memory::desc &diff_weights_desc,
5232  const memory::desc &diff_bias_desc,
5233  const memory::desc &diff_dst_desc, const memory::dims &strides,
5234  const memory::dims &dilates, const memory::dims &padding_l,
5235  const memory::dims &padding_r) {
5236  memory::validate_dims(strides, src_desc.data.ndims - 2);
5237  memory::validate_dims(dilates, src_desc.data.ndims - 2);
5238  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
5239  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
5242  convert_to_c(aalgorithm), &src_desc.data,
5243  &diff_weights_desc.data, &diff_bias_desc.data,
5244  &diff_dst_desc.data, &strides[0], &dilates[0],
5245  &padding_l[0], &padding_r[0]),
5246  "could not create a descriptor for a dilated deconvolution "
5247  "weights gradient primitive");
5248  }
5249 
5276  desc(algorithm aalgorithm, const memory::desc &src_desc,
5277  const memory::desc &diff_weights_desc,
5278  const memory::desc &diff_dst_desc, const memory::dims &strides,
5279  const memory::dims &dilates, const memory::dims &padding_l,
5280  const memory::dims &padding_r) {
5281  memory::validate_dims(strides, src_desc.data.ndims - 2);
5282  memory::validate_dims(dilates, src_desc.data.ndims - 2);
5283  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
5284  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
5287  convert_to_c(aalgorithm), &src_desc.data,
5288  &diff_weights_desc.data, nullptr,
5289  &diff_dst_desc.data, &strides[0], &dilates[0],
5290  &padding_l[0], &padding_r[0]),
5291  "could not create a descriptor for a dilated deconvolution "
5292  "weights gradient primitive");
5293  }
5294  };
5295 
5299  primitive_desc() = default;
5300 
5314  primitive_desc(const desc &adesc, const engine &aengine,
5315  const deconvolution_forward::primitive_desc &hint_fwd_pd,
5316  bool allow_empty = false)
5317  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
5318  hint_fwd_pd.get(), allow_empty) {}
5319 
5334  primitive_desc(const desc &adesc, const primitive_attr &attr,
5335  const engine &aengine,
5336  const deconvolution_forward::primitive_desc &hint_fwd_pd,
5337  bool allow_empty = false)
5338  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
5339  hint_fwd_pd.get(), allow_empty) {}
5340 
5348  : dnnl::primitive_desc(pd, dnnl::primitive::kind::deconvolution,
5350 
5352  memory::desc src_desc() const { return base::src_desc(0); }
5353 
5356  return base::diff_weights_desc(0);
5357  }
5358 
5361 
5364  return base::diff_weights_desc(1);
5365  }
5366  };
5367 
5370 
5375 };
5376 
5378 
5387 
5389 struct lrn_forward : public primitive {
5391  struct desc {
5392  dnnl_lrn_desc_t data;
5393 
5407  desc(prop_kind aprop_kind, algorithm aalgorithm,
5408  const memory::desc &data_desc, memory::dim local_size,
5409  float alpha, float beta, float k = 1.f) {
5411  dnnl::convert_to_c(aprop_kind),
5412  convert_to_c(aalgorithm), &data_desc.data,
5413  local_size, alpha, beta, k),
5414  "could not create a descriptor for a lrn forward "
5415  "propagation primitive");
5416  }
5417  };
5418 
5422  primitive_desc() = default;
5423 
5433  primitive_desc(const desc &adesc, const engine &aengine,
5434  bool allow_empty = false)
5435  : dnnl::primitive_desc(
5436  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
5437 
5448  primitive_desc(const desc &adesc, const primitive_attr &attr,
5449  const engine &aengine, bool allow_empty = false)
5450  : dnnl::primitive_desc(
5451  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
5452 
5460  : dnnl::primitive_desc(pd, dnnl::primitive::kind::lrn,
5463 
5465  memory::desc src_desc() const { return base::src_desc(0); }
5466 
5468  memory::desc dst_desc() const { return base::dst_desc(0); }
5469 
5472  };
5473 
5475  lrn_forward() = default;
5476 
5481 };
5482 
5484 struct lrn_backward : public primitive {
5486  struct desc {
5487  dnnl_lrn_desc_t data;
5488 
5501  desc(algorithm aalgorithm, const memory::desc &data_desc,
5502  const memory::desc &diff_data_desc, memory::dim local_size,
5503  float alpha, float beta, float k = 1.f) {
5505  dnnl_lrn_backward_desc_init(&data, convert_to_c(aalgorithm),
5506  &diff_data_desc.data, &data_desc.data, local_size,
5507  alpha, beta, k),
5508  "could not create a descriptor for a lrn backward "
5509  "propagation primitive");
5510  }
5511  };
5512 
5516  primitive_desc() = default;
5517 
5530  primitive_desc(const desc &adesc, const engine &aengine,
5531  const lrn_forward::primitive_desc &hint_fwd_pd,
5532  bool allow_empty = false)
5533  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
5534  hint_fwd_pd.get(), allow_empty) {}
5535 
5549  primitive_desc(const desc &adesc, const primitive_attr &attr,
5550  const engine &aengine,
5551  const lrn_forward::primitive_desc &hint_fwd_pd,
5552  bool allow_empty = false)
5553  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
5554  hint_fwd_pd.get(), allow_empty) {}
5555 
5563  : dnnl::primitive_desc(pd, dnnl::primitive::kind::lrn,
5565 
5568 
5571 
5574  };
5575 
5577  lrn_backward() = default;
5578 
5583 };
5584 
5586 
5594 
5596 struct pooling_forward : public primitive {
5598  struct desc {
5599  dnnl_pooling_desc_t data;
5600 
5625  desc(prop_kind aprop_kind, algorithm aalgorithm,
5626  const memory::desc &src_desc, const memory::desc &dst_desc,
5627  const memory::dims &strides, const memory::dims &kernel,
5628  const memory::dims &padding_l, const memory::dims &padding_r) {
5629  memory::validate_dims(strides, src_desc.data.ndims - 2);
5630  memory::validate_dims(kernel, src_desc.data.ndims - 2);
5631  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
5632  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
5634  dnnl::convert_to_c(aprop_kind),
5635  convert_to_c(aalgorithm), &src_desc.data,
5636  &dst_desc.data, &strides[0], &kernel[0],
5637  &padding_l[0], &padding_r[0]),
5638  "could not create a descriptor for a pooling forward "
5639  "propagation primitive");
5640  }
5641  };
5642 
5646  primitive_desc() = default;
5647 
5657  primitive_desc(const desc &adesc, const engine &aengine,
5658  bool allow_empty = false)
5659  : dnnl::primitive_desc(
5660  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
5661 
5672  primitive_desc(const desc &adesc, const primitive_attr &attr,
5673  const engine &aengine, bool allow_empty = false)
5674  : dnnl::primitive_desc(
5675  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
5676 
5684  : dnnl::primitive_desc(pd, dnnl::primitive::kind::pooling,
5687 
5689  memory::desc src_desc() const { return base::src_desc(0); }
5690 
5692  memory::desc dst_desc() const { return base::dst_desc(0); }
5693 
5696  };
5697 
5699  pooling_forward() = default;
5700 
5705 };
5706 
5708 struct pooling_backward : public primitive {
5710  struct desc {
5711  dnnl_pooling_desc_t data;
5712 
5734  desc(algorithm aalgorithm, const memory::desc &diff_src_desc,
5735  const memory::desc &diff_dst_desc, const memory::dims &strides,
5736  const memory::dims &kernel, const memory::dims &padding_l,
5737  const memory::dims &padding_r) {
5738  memory::validate_dims(strides, diff_src_desc.data.ndims - 2);
5739  memory::validate_dims(kernel, diff_src_desc.data.ndims - 2);
5740  memory::validate_dims(padding_l, diff_src_desc.data.ndims - 2);
5741  memory::validate_dims(padding_r, diff_src_desc.data.ndims - 2);
5744  convert_to_c(aalgorithm), &diff_src_desc.data,
5745  &diff_dst_desc.data, &strides[0], &kernel[0],
5746  &padding_l[0], &padding_r[0]),
5747  "could not create a descriptor for a pooling backward "
5748  "propagation primitive");
5749  }
5750  };
5751 
5755  primitive_desc() = default;
5756 
5769  primitive_desc(const desc &adesc, const engine &aengine,
5770  const pooling_forward::primitive_desc &hint_fwd_pd,
5771  bool allow_empty = false)
5772  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
5773  hint_fwd_pd.get(), allow_empty) {}
5774 
5788  primitive_desc(const desc &adesc, const primitive_attr &attr,
5789  const engine &aengine,
5790  const pooling_forward::primitive_desc &hint_fwd_pd,
5791  bool allow_empty = false)
5792  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
5793  hint_fwd_pd.get(), allow_empty) {}
5794 
5802  : dnnl::primitive_desc(pd, dnnl::primitive::kind::pooling,
5804 
5807 
5810 
5813  };
5814 
5816  pooling_backward() = default;
5817 
5822 };
5823 
5825 
5846 
5848 struct eltwise_forward : public primitive {
5850  struct desc {
5851  dnnl_eltwise_desc_t data;
5852 
5865  desc(prop_kind aprop_kind, algorithm aalgorithm,
5866  const memory::desc &data_desc, float alpha = 0,
5867  float beta = 0) {
5869  dnnl::convert_to_c(aprop_kind),
5870  dnnl::convert_to_c(aalgorithm),
5871  &data_desc.data, alpha, beta),
5872  "could not create a descriptor for an eltwise forward "
5873  "propagation primitive");
5874  }
5875  };
5876 
5880  primitive_desc() = default;
5881 
5892  primitive_desc(const desc &adesc, const engine &aengine,
5893  bool allow_empty = false)
5894  : dnnl::primitive_desc(
5895  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
5896 
5908  primitive_desc(const desc &adesc, const primitive_attr &attr,
5909  const engine &aengine, bool allow_empty = false)
5910  : dnnl::primitive_desc(
5911  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
5912 
5920  : dnnl::primitive_desc(pd, dnnl::primitive::kind::eltwise,
5923 
5925  memory::desc src_desc() const { return base::src_desc(0); }
5926 
5928  memory::desc dst_desc() const { return base::dst_desc(0); }
5929  };
5930 
5932  eltwise_forward() = default;
5933 
5938 };
5939 
5941 struct eltwise_backward : public primitive {
5943  struct desc {
5944  dnnl_eltwise_desc_t data;
5945 
5957  desc(algorithm aalgorithm, const memory::desc &diff_data_desc,
5958  const memory::desc &data_desc, float alpha = 0,
5959  float beta = 0) {
5962  dnnl::convert_to_c(aalgorithm),
5963  &diff_data_desc.data, &data_desc.data, alpha, beta),
5964  "could not create a descriptor for an eltwise backward "
5965  "propagation primitive");
5966  }
5967  };
5968 
5972  primitive_desc() = default;
5973 
5987  primitive_desc(const desc &adesc, const engine &aengine,
5988  const eltwise_forward::primitive_desc &hint_fwd_pd,
5989  bool allow_empty = false)
5990  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
5991  hint_fwd_pd.get(), allow_empty) {}
5992 
6007  primitive_desc(const desc &adesc, const primitive_attr &attr,
6008  const engine &aengine,
6009  const eltwise_forward::primitive_desc &hint_fwd_pd,
6010  bool allow_empty = false)
6011  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
6012  hint_fwd_pd.get(), allow_empty) {}
6013 
6021  : dnnl::primitive_desc(pd, dnnl::primitive::kind::eltwise,
6023 
6025  memory::desc src_desc() const { return base::src_desc(0); }
6026 
6029 
6032  };
6033 
6035  eltwise_backward() = default;
6036 
6041 };
6042 
6044 
6052 
6054 struct softmax_forward : public primitive {
6056  struct desc {
6057  dnnl_softmax_desc_t data;
6058 
6060  desc() = default;
6061 
6070  desc(prop_kind aprop_kind, const memory::desc &data_desc,
6071  int softmax_axis) {
6073  dnnl::convert_to_c(aprop_kind),
6074  &data_desc.data, softmax_axis),
6075  "could not create a descriptor for a softmax forward "
6076  "propagation primitive");
6077  }
6078  };
6079 
6083  primitive_desc() = default;
6084 
6095  primitive_desc(const desc &adesc, const engine &aengine,
6096  bool allow_empty = false)
6097  : dnnl::primitive_desc(
6098  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
6099 
6111  primitive_desc(const desc &adesc, const primitive_attr &attr,
6112  const engine &aengine, bool allow_empty = false)
6113  : dnnl::primitive_desc(
6114  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
6115 
6123  : dnnl::primitive_desc(pd, dnnl::primitive::kind::softmax,
6126 
6128  memory::desc src_desc() const { return base::src_desc(0); }
6129 
6131  memory::desc dst_desc() const { return base::dst_desc(0); }
6132  };
6133 
6135  softmax_forward() = default;
6136 
6141 };
6142 
6144 struct softmax_backward : public primitive {
6146  struct desc {
6147  dnnl_softmax_desc_t data;
6148 
6150  desc() = default;
6151 
6159  desc(const memory::desc &diff_data_desc, const memory::desc &data_desc,
6160  int softmax_axis) {
6162  dnnl_softmax_backward_desc_init(&data, &diff_data_desc.data,
6163  &data_desc.data, softmax_axis),
6164  "could not create a descriptor for a softmax backward "
6165  "propagation primitive");
6166  }
6167  };
6168 
6172  primitive_desc() = default;
6173 
6187  primitive_desc(const desc &adesc, const engine &aengine,
6188  const softmax_forward::primitive_desc &hint_fwd_pd,
6189  bool allow_empty = false)
6190  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
6191  hint_fwd_pd.get(), allow_empty) {}
6192 
6207  primitive_desc(const desc &adesc, const primitive_attr &attr,
6208  const engine &aengine,
6209  const softmax_forward::primitive_desc &hint_fwd_pd,
6210  bool allow_empty = false)
6211  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
6212  hint_fwd_pd.get(), allow_empty) {}
6213 
6221  : dnnl::primitive_desc(pd, dnnl::primitive::kind::softmax,
6223 
6225  memory::desc dst_desc() const { return base::dst_desc(0); }
6226 
6229 
6232  };
6233 
6235  softmax_backward() = default;
6236 
6241 };
6242 
6244 
6252 
6256  struct desc {
6258 
6260  desc() = default;
6261 
6270  desc(prop_kind aprop_kind, const memory::desc &data_desc,
6271  int logsoftmax_axis) {
6273  dnnl::convert_to_c(aprop_kind),
6274  &data_desc.data, logsoftmax_axis),
6275  "could not create a descriptor for a logsoftmax forward "
6276  "propagation primitive");
6277  }
6278  };
6279 
6283  primitive_desc() = default;
6284 
6295  primitive_desc(const desc &adesc, const engine &aengine,
6296  bool allow_empty = false)
6297  : dnnl::primitive_desc(
6298  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
6299 
6311  primitive_desc(const desc &adesc, const primitive_attr &attr,
6312  const engine &aengine, bool allow_empty = false)
6313  : dnnl::primitive_desc(
6314  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
6315 
6323  : dnnl::primitive_desc(pd,
6324  // Logsoftmax and softmax share the implementation and
6325  // currently report the same primitive kind. Hence this
6326  // must be softmax and not logsoftmax.
6327  dnnl::primitive::kind::softmax,
6330 
6332  memory::desc src_desc() const { return base::src_desc(0); }
6333 
6335  memory::desc dst_desc() const { return base::dst_desc(0); }
6336  };
6337 
6339  logsoftmax_forward() = default;
6340 
6345 };
6346 
6350  struct desc {
6352 
6354  desc() = default;
6355 
6363  desc(const memory::desc &diff_data_desc, const memory::desc &data_desc,
6364  int logsoftmax_axis) {
6366  &diff_data_desc.data, &data_desc.data,
6367  logsoftmax_axis),
6368  "could not create a descriptor for a logsoftmax backward "
6369  "propagation primitive");
6370  }
6371  };
6372 
6376  primitive_desc() = default;
6377 
6391  primitive_desc(const desc &adesc, const engine &aengine,
6392  const logsoftmax_forward::primitive_desc &hint_fwd_pd,
6393  bool allow_empty = false)
6394  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
6395  hint_fwd_pd.get(), allow_empty) {}
6396 
6411  primitive_desc(const desc &adesc, const primitive_attr &attr,
6412  const engine &aengine,
6413  const logsoftmax_forward::primitive_desc &hint_fwd_pd,
6414  bool allow_empty = false)
6415  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
6416  hint_fwd_pd.get(), allow_empty) {}
6417 
6425  : dnnl::primitive_desc(pd,
6426  // Logsoftmax and softmax share the implementation and
6427  // currently report the same primitive kind. Hence this
6428  // must be softmax and not logsoftmax.
6429  dnnl::primitive::kind::softmax,
6431 
6433  memory::desc dst_desc() const { return base::dst_desc(0); }
6434 
6437 
6440  };
6441 
6443  logsoftmax_backward() = default;
6444 
6449 };
6450 
6452 
6472 
6476  struct desc {
6478 
6493  desc(prop_kind aprop_kind, const memory::desc &data_desc, float epsilon,
6494  normalization_flags flags) {
6497  dnnl::convert_to_c(aprop_kind), &data_desc.data,
6498  epsilon, convert_to_c(flags)),
6499  "could not create a descriptor for a batch normalization "
6500  "forward propagation primitive");
6501  }
6502  };
6503 
6508  primitive_desc() = default;
6509 
6520  primitive_desc(const desc &adesc, const engine &aengine,
6521  bool allow_empty = false)
6522  : dnnl::primitive_desc(
6523  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
6524 
6536  primitive_desc(const desc &adesc, const primitive_attr &attr,
6537  const engine &aengine, bool allow_empty = false)
6538  : dnnl::primitive_desc(
6539  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
6540 
6548  : dnnl::primitive_desc(pd,
6549  dnnl::primitive::kind::batch_normalization,
6552 
6554  memory::desc src_desc() const { return base::src_desc(0); }
6555 
6557  memory::desc dst_desc() const { return base::dst_desc(0); }
6558 
6561 
6564 
6567  memory::desc mean_desc() const { return stat_desc(mean); }
6568 
6571  memory::desc variance_desc() const { return stat_desc(var); }
6572 
6573  private:
6574  enum {
6575  mean = 1,
6576  var = 2,
6577  };
6578  memory::desc stat_desc(int kind) const {
6583  &p),
6584  "could not retrieve a descriptor from a primitive "
6585  "descriptor for batch normalization forward propagation "
6586  "primitive");
6587  return query_md(p->flags & dnnl_use_global_stats ? query::src_md
6588  : query::dst_md,
6589  kind);
6590  }
6591  };
6592 
6595 
6600 };
6601 
6605  struct desc {
6607 
6620  desc(prop_kind aprop_kind, const memory::desc &diff_data_desc,
6621  const memory::desc &data_desc, float epsilon,
6622  normalization_flags flags) {
6624  dnnl::convert_to_c(aprop_kind),
6625  &diff_data_desc.data, &data_desc.data,
6626  epsilon, convert_to_c(flags)),
6627  "could not create a descriptor for a batch normalization "
6628  "backward propagation primitive");
6629  }
6630  };
6631 
6636  primitive_desc() = default;
6637 
6651  primitive_desc(const desc &adesc, const engine &aengine,
6653  bool allow_empty = false)
6654  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
6655  hint_fwd_pd.get(), allow_empty) {}
6656 
6671  primitive_desc(const desc &adesc, const primitive_attr &attr,
6672  const engine &aengine,
6674  bool allow_empty = false)
6675  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
6676  hint_fwd_pd.get(), allow_empty) {}
6677 
6685  : dnnl::primitive_desc(pd,
6686  dnnl::primitive::kind::batch_normalization,
6688  }
6689 
6691  memory::desc src_desc() const { return base::src_desc(0); }
6692 
6695 
6697  memory::desc dst_desc() const { return base::dst_desc(0); }
6698 
6701 
6704 
6707  return base::diff_weights_desc(0);
6708  }
6709 
6712 
6715  return query_md(query::src_md, 2);
6716  }
6717 
6720  };
6721 
6724 
6729 };
6730 
6732 
6754 
6758  struct desc {
6760 
6772  desc(prop_kind aprop_kind, const memory::desc &data_desc,
6773  const memory::desc &stat_desc, float epsilon,
6774  normalization_flags flags) {
6777  dnnl::convert_to_c(aprop_kind), &data_desc.data,
6778  &stat_desc.data, epsilon, convert_to_c(flags)),
6779  "could not create a descriptor for a layer normalization "
6780  "forward propagation primitive");
6781  }
6782 
6793  desc(prop_kind aprop_kind, const memory::desc &data_desc, float epsilon,
6794  normalization_flags flags) {
6797  dnnl::convert_to_c(aprop_kind), &data_desc.data,
6798  nullptr, epsilon, convert_to_c(flags)),
6799  "could not create a descriptor for a layer normalization "
6800  "forward propagation primitive");
6801  }
6802  };
6803 
6808  primitive_desc() = default;
6809 
6820  primitive_desc(const desc &adesc, const engine &aengine,
6821  bool allow_empty = false)
6822  : dnnl::primitive_desc(
6823  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
6824 
6836  primitive_desc(const desc &adesc, const primitive_attr &attr,
6837  const engine &aengine, bool allow_empty = false)
6838  : dnnl::primitive_desc(
6839  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
6840 
6848  : dnnl::primitive_desc(pd,
6849  dnnl::primitive::kind::layer_normalization,
6852 
6854  memory::desc src_desc() const { return base::src_desc(0); }
6855 
6857  memory::desc dst_desc() const { return base::dst_desc(0); }
6858 
6861 
6864 
6866  memory::desc mean_desc() const { return stat_desc(mean); }
6867 
6869  memory::desc variance_desc() const { return stat_desc(var); }
6870 
6871  private:
6872  enum {
6873  mean = 1,
6874  var = 2,
6875  };
6876  memory::desc stat_desc(int kind) const {
6881  &p),
6882  "could not retrieve a descriptor from a primitive "
6883  "descriptor for layer normalization forward propagation "
6884  "primitive");
6885  return query_md(p->flags & dnnl_use_global_stats ? query::src_md
6886  : query::dst_md,
6887  kind);
6888  }
6889  };
6890 
6893 
6898 };
6899 
6903  struct desc {
6905 
6919  desc(prop_kind aprop_kind, const memory::desc &diff_data_desc, // overload taking an explicit statistics descriptor
6920  const memory::desc &data_desc, const memory::desc &stat_desc,
6921  float epsilon, normalization_flags flags) {
6924  dnnl::convert_to_c(aprop_kind),
6925  &diff_data_desc.data, &data_desc.data,
6926  &stat_desc.data, epsilon, convert_to_c(flags)),
6927  "could not create a descriptor for a layer normalization " // message names the layer (not batch) normalization primitive
6928  "backward propagation primitive");
6929  }
6930 
6943  desc(prop_kind aprop_kind, const memory::desc &diff_data_desc, // overload without a statistics descriptor: nullptr is passed in its place
6944  const memory::desc &data_desc, float epsilon,
6945  normalization_flags flags) {
6947  dnnl::convert_to_c(aprop_kind),
6948  &diff_data_desc.data, &data_desc.data,
6949  nullptr, epsilon, convert_to_c(flags)),
6950  "could not create a descriptor for a layer normalization "
6951  "backward propagation primitive");
6952  }
6953  };
6954 
6959  primitive_desc() = default;
6960 
6974  primitive_desc(const desc &adesc, const engine &aengine,
6976  bool allow_empty = false)
6977  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
6978  hint_fwd_pd.get(), allow_empty) {}
6979 
6994  primitive_desc(const desc &adesc, const primitive_attr &attr,
6995  const engine &aengine,
6997  bool allow_empty = false)
6998  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
6999  hint_fwd_pd.get(), allow_empty) {}
7000 
7008  : dnnl::primitive_desc(pd,
7009  dnnl::primitive::kind::layer_normalization,
7011  }
7012 
7014  memory::desc src_desc() const { return base::src_desc(0); }
7015 
7018 
7020  memory::desc dst_desc() const { return base::dst_desc(0); }
7021 
7024 
7027 
7030  return base::diff_weights_desc(0);
7031  }
7032 
7035 
7038  return query_md(query::src_md, 2);
7039  }
7040 
7043  };
7044 
7047 
7052 };
7053 
7055 
7063 
7067  struct desc {
7069 
7084  desc(prop_kind aprop_kind, const memory::desc &src_desc,
7085  const memory::desc &weights_desc, const memory::desc &bias_desc,
7086  const memory::desc &dst_desc) {
7088  dnnl::convert_to_c(aprop_kind),
7089  &src_desc.data, &weights_desc.data,
7090  &bias_desc.data, &dst_desc.data),
7091  "could not create a descriptor for an inner product "
7092  "forward propagation primitive");
7093  }
7094 
7108  desc(prop_kind aprop_kind, const memory::desc &src_desc,
7109  const memory::desc &weights_desc,
7110  const memory::desc &dst_desc) {
7113  dnnl::convert_to_c(aprop_kind), &src_desc.data,
7114  &weights_desc.data, nullptr, &dst_desc.data),
7115  "could not create a descriptor for an inner product "
7116  "forward propagation primitive");
7117  }
7118  };
7119 
7123  primitive_desc() = default;
7124 
7135  primitive_desc(const desc &adesc, const engine &aengine,
7136  bool allow_empty = false)
7137  : dnnl::primitive_desc(
7138  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
7139 
7151  primitive_desc(const desc &adesc, const primitive_attr &attr,
7152  const engine &aengine, bool allow_empty = false)
7153  : dnnl::primitive_desc(
7154  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
7155 
7163  : dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product,
7166 
7168  memory::desc src_desc() const { return base::src_desc(0); }
7169 
7172 
7174  memory::desc dst_desc() const { return base::dst_desc(0); }
7175 
7178  };
7179 
7182 
7187 };
7188 
7192  struct desc {
7194 
7205  desc(const memory::desc &diff_src_desc,
7206  const memory::desc &weights_desc,
7207  const memory::desc &diff_dst_desc) {
7209  &diff_src_desc.data, &weights_desc.data,
7210  &diff_dst_desc.data),
7211  "could not create a descriptor for an inner product "
7212  "backward propagation primitive");
7213  }
7214  };
7215 
7220  primitive_desc() = default;
7221 
7235  primitive_desc(const desc &adesc, const engine &aengine,
7236  const inner_product_forward::primitive_desc &hint_fwd_pd,
7237  bool allow_empty = false)
7238  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
7239  hint_fwd_pd.get(), allow_empty) {}
7240 
7255  primitive_desc(const desc &adesc, const primitive_attr &attr,
7256  const engine &aengine,
7257  const inner_product_forward::primitive_desc &hint_fwd_pd,
7258  bool allow_empty = false)
7259  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
7260  hint_fwd_pd.get(), allow_empty) {}
7261 
7269  : dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product,
7271 
7274 
7277 
7280  };
7281 
7284 
7289 };
7290 
7294  struct desc {
7296 
7308  desc(const memory::desc &src_desc,
7309  const memory::desc &diff_weights_desc,
7310  const memory::desc &diff_bias_desc,
7311  const memory::desc &diff_dst_desc) {
7314  &src_desc.data, &diff_weights_desc.data,
7315  &diff_bias_desc.data, &diff_dst_desc.data),
7316  "could not create a descriptor for an inner product "
7317  "weights gradient primitive");
7318  }
7319 
7330  desc(const memory::desc &src_desc,
7331  const memory::desc &diff_weights_desc,
7332  const memory::desc &diff_dst_desc) {
7335  &src_desc.data, &diff_weights_desc.data, nullptr,
7336  &diff_dst_desc.data),
7337  "could not create a descriptor for an inner product "
7338  "weights gradient primitive");
7339  }
7340  };
7341 
7345  primitive_desc() = default;
7346 
7360  primitive_desc(const desc &adesc, const engine &aengine,
7361  const inner_product_forward::primitive_desc &hint_fwd_pd,
7362  bool allow_empty = false)
7363  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
7364  hint_fwd_pd.get(), allow_empty) {}
7365 
7380  primitive_desc(const desc &adesc, const primitive_attr &attr,
7381  const engine &aengine,
7382  const inner_product_forward::primitive_desc &hint_fwd_pd,
7383  bool allow_empty = false)
7384  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
7385  hint_fwd_pd.get(), allow_empty) {}
7386 
7394  : dnnl::primitive_desc(pd, dnnl::primitive::kind::inner_product,
7396 
7398  memory::desc src_desc() const { return base::src_desc(0); }
7399 
7402  return base::diff_weights_desc(0);
7403  }
7404 
7407 
7410  return base::diff_weights_desc(1);
7411  }
7412  };
7413 
7416 
7421 };
7422 
7424 
7432 
7435  using primitive_desc::primitive_desc;
7436 
7439 
7448  dnnl::prop_kind aprop_kind, dnnl::algorithm cell_kind)
7449  : rnn_primitive_desc_base(pd, aprop_kind, aprop_kind, cell_kind) {}
7450 
7455  }
7456 
7463  }
7464 
7469  }
7470 
7475  }
7476 
7481  }
7482 
7487  }
7488 
7493  }
7494 
7501  }
7502 
7507  }
7508 
7515  }
7516 
7521  }
7522 
7527  }
7528 
7535  }
7536 
7541  }
7542 
7547  }
7548 
7553  }
7554 
7558  return base::query_md(
7560  }
7561 
7565  return base::query_md(
7567  }
7568 
7575  }
7576 
7581  }
7582 
7589  }
7590 
7595  }
7596 
7597 protected:
7598  using rnn_base = rnn_primitive_desc_base;
7599 
7600  // (Deliberately not using doxygen comments)
7601  //
7602  // Constructs an RNN primitive descriptor base from a C API primitive
7603  // descriptor while checking that it actually describes the expected
7604  // primitive by comparing primitive, propagation, and cell kinds. The
7605  // caller passes two acceptable propagation kinds, which may be equal;
7606  // typically used to accept either forward inference or forward training.
7607  //
7608  // @param pd C API primitive descriptor.
7609  // @param prop_kind1 First acceptable propagation kind.
7610  // @param prop_kind2 Second acceptable propagation kind.
7611  // @param cell_kind Expected cell kind (RNN algorithm).
7613  dnnl::prop_kind prop_kind1, dnnl::prop_kind prop_kind2,
7614  dnnl::algorithm cell_kind) {
7616  dnnl_status_t rc;
7617  rc = dnnl_primitive_desc_query(pd, dnnl_query_rnn_d, 0, &rnn_d); // fetch the RNN op descriptor backing pd
7618  error::wrap_c_api(rc,
7619  "could not retrieve a descriptor from a primitive descriptor "
7620  "for an RNN primitive");
7621 
7622  dnnl_prop_kind_t c_prop_kind1 = convert_to_c(prop_kind1);
7623  dnnl_prop_kind_t c_prop_kind2 = convert_to_c(prop_kind2);
7624  dnnl_alg_kind_t c_cell_kind = convert_to_c(cell_kind);
7625 
7626  bool ok = rnn_d->primitive_kind == dnnl_rnn
7627  && (rnn_d->prop_kind == c_prop_kind1
7628  || rnn_d->prop_kind == c_prop_kind2)
7629  && rnn_d->cell_kind == c_cell_kind;
7630 
7631  if (!ok) // primitive, propagation, or cell kind does not match the expectation
7632  DNNL_THROW_ERROR(dnnl_invalid_arguments,
7633  "mismatch between expected and provided descriptors for an "
7634  "RNN primitive");
7635 
7636  reset_with_clone(pd);
7637  }
7638 };
7639 
7643  struct desc {
7644  dnnl_rnn_desc_t data;
7645 
7686  desc(prop_kind aprop_kind, algorithm activation,
7687  rnn_direction direction, const memory::desc &src_layer_desc,
7688  const memory::desc &src_iter_desc,
7689  const memory::desc &weights_layer_desc,
7690  const memory::desc &weights_iter_desc,
7691  const memory::desc &bias_desc,
7692  const memory::desc &dst_layer_desc,
7693  const memory::desc &dst_iter_desc,
7694  rnn_flags flags = rnn_flags::undef, float alpha = 0.0f,
7695  float beta = 0.0f) {
7698  dnnl::convert_to_c(aprop_kind),
7699  dnnl::convert_to_c(activation),
7700  dnnl::convert_to_c(direction), &src_layer_desc.data,
7701  &src_iter_desc.data, &weights_layer_desc.data,
7702  &weights_iter_desc.data, &bias_desc.data,
7703  &dst_layer_desc.data, &dst_iter_desc.data,
7704  dnnl::convert_to_c(flags), alpha, beta),
7705  "could not create a descriptor for a vanilla RNN forward "
7706  "propagation primitive");
7707  }
7708  };
7709 
7713  primitive_desc() = default;
7714 
7725  primitive_desc(const desc &adesc, const engine &aengine,
7726  bool allow_empty = false)
7728  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
7729 
7741  primitive_desc(const desc &adesc, const primitive_attr &attr,
7742  const engine &aengine, bool allow_empty = false)
7744  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
7745 
7755  dnnl::algorithm::vanilla_rnn) {}
7756 
7759  return rnn_base::src_layer_desc();
7760  }
7761 
7764 
7768  }
7769 
7772  return rnn_base::weights_iter_desc();
7773  }
7774 
7777 
7780  return rnn_base::dst_layer_desc();
7781  }
7782 
7785 
7788  return rnn_base::workspace_desc();
7789  }
7790  };
7791 
7793  vanilla_rnn_forward() = default;
7794 
7799 };
7800 
7804  struct desc {
7805  dnnl_rnn_desc_t data;
7806 
7859  desc(prop_kind aprop_kind, algorithm activation,
7860  rnn_direction direction, const memory::desc &src_layer_desc,
7861  const memory::desc &src_iter_desc,
7862  const memory::desc &weights_layer_desc,
7863  const memory::desc &weights_iter_desc,
7864  const memory::desc &bias_desc,
7865  const memory::desc &dst_layer_desc,
7866  const memory::desc &dst_iter_desc,
7867  const memory::desc &diff_src_layer_desc,
7868  const memory::desc &diff_src_iter_desc,
7869  const memory::desc &diff_weights_layer_desc,
7870  const memory::desc &diff_weights_iter_desc,
7871  const memory::desc &diff_bias_desc,
7872  const memory::desc &diff_dst_layer_desc,
7873  const memory::desc &diff_dst_iter_desc,
7874  rnn_flags flags = rnn_flags::undef, float alpha = 0.0f,
7875  float beta = 0.0f) {
7878  dnnl::convert_to_c(aprop_kind),
7879  dnnl::convert_to_c(activation),
7880  dnnl::convert_to_c(direction), &src_layer_desc.data,
7881  &src_iter_desc.data, &weights_layer_desc.data,
7882  &weights_iter_desc.data, &bias_desc.data,
7883  &dst_layer_desc.data, &dst_iter_desc.data,
7884  &diff_src_layer_desc.data, &diff_src_iter_desc.data,
7885  &diff_weights_layer_desc.data,
7886  &diff_weights_iter_desc.data, &diff_bias_desc.data,
7887  &diff_dst_layer_desc.data, &diff_dst_iter_desc.data,
7888  dnnl::convert_to_c(flags), alpha, beta),
7889  "could not create a descriptor for a vanilla RNN backward "
7890  "propagation primitive");
7891  }
7892  };
7893 
7897  primitive_desc() = default;
7898 
7912  primitive_desc(const desc &adesc, const engine &aengine,
7913  const vanilla_rnn_forward::primitive_desc &hint_fwd_pd,
7914  bool allow_empty = false)
7915  : rnn_primitive_desc_base(&adesc.data, nullptr, aengine,
7916  hint_fwd_pd.get(), allow_empty) {}
7917 
7932  primitive_desc(const desc &adesc, const primitive_attr &attr,
7933  const engine &aengine,
7934  const vanilla_rnn_forward::primitive_desc &hint_fwd_pd,
7935  bool allow_empty = false)
7936  : rnn_primitive_desc_base(&adesc.data, &attr, aengine,
7937  hint_fwd_pd.get(), allow_empty) {}
7938 
7947  dnnl::algorithm::vanilla_rnn) {}
7948 
7951  return rnn_base::src_layer_desc();
7952  }
7953 
7956 
7960  }
7961 
7964  return rnn_base::weights_iter_desc();
7965  }
7966 
7969 
7972  return rnn_base::dst_layer_desc();
7973  }
7974 
7977 
7980  return rnn_base::workspace_desc();
7981  }
7982 
7986  }
7987 
7991  }
7992 
7996  }
7997 
8001  }
8002 
8005  return rnn_base::diff_bias_desc();
8006  }
8007 
8011  }
8012 
8016  }
8017  };
8018 
8021 
8026 };
8027 
8029 struct lstm_forward : public primitive {
8031  struct desc {
8032  dnnl_rnn_desc_t data;
8033 
8082  desc(prop_kind aprop_kind, rnn_direction direction,
8083  const memory::desc &src_layer_desc,
8084  const memory::desc &src_iter_desc,
8085  const memory::desc &src_iter_c_desc,
8086  const memory::desc &weights_layer_desc,
8087  const memory::desc &weights_iter_desc,
8088  const memory::desc &weights_peephole_desc,
8089  const memory::desc &weights_projection_desc,
8090  const memory::desc &bias_desc,
8091  const memory::desc &dst_layer_desc,
8092  const memory::desc &dst_iter_desc,
8093  const memory::desc &dst_iter_c_desc,
8094  rnn_flags flags = rnn_flags::undef) {
8097  dnnl::convert_to_c(aprop_kind),
8098  dnnl::convert_to_c(direction), &src_layer_desc.data,
8099  &src_iter_desc.data, &src_iter_c_desc.data,
8100  &weights_layer_desc.data, &weights_iter_desc.data,
8101  &weights_peephole_desc.data,
8102  &weights_projection_desc.data, &bias_desc.data,
8103  &dst_layer_desc.data, &dst_iter_desc.data,
8104  &dst_iter_c_desc.data, dnnl::convert_to_c(flags)),
8105  "could not create a descriptor for an LSTM forward "
8106  "propagation primitive");
8107  }
8108 
8150  desc(prop_kind aprop_kind, rnn_direction direction,
8151  const memory::desc &src_layer_desc,
8152  const memory::desc &src_iter_desc,
8153  const memory::desc &src_iter_c_desc,
8154  const memory::desc &weights_layer_desc,
8155  const memory::desc &weights_iter_desc,
8156  const memory::desc &weights_peephole_desc,
8157  const memory::desc &bias_desc,
8158  const memory::desc &dst_layer_desc,
8159  const memory::desc &dst_iter_desc,
8160  const memory::desc &dst_iter_c_desc,
8161  rnn_flags flags = rnn_flags::undef) {
8164  dnnl::convert_to_c(aprop_kind),
8165  dnnl::convert_to_c(direction), &src_layer_desc.data,
8166  &src_iter_desc.data, &src_iter_c_desc.data,
8167  &weights_layer_desc.data, &weights_iter_desc.data,
8168  &weights_peephole_desc.data, &bias_desc.data,
8169  &dst_layer_desc.data, &dst_iter_desc.data,
8170  &dst_iter_c_desc.data, dnnl::convert_to_c(flags)),
8171  "could not create a descriptor for an LSTM forward "
8172  "propagation primitive");
8173  }
8174 
8211  desc(prop_kind aprop_kind, rnn_direction direction,
8212  const memory::desc &src_layer_desc,
8213  const memory::desc &src_iter_desc,
8214  const memory::desc &src_iter_c_desc,
8215  const memory::desc &weights_layer_desc,
8216  const memory::desc &weights_iter_desc,
8217  const memory::desc &bias_desc,
8218  const memory::desc &dst_layer_desc,
8219  const memory::desc &dst_iter_desc,
8220  const memory::desc &dst_iter_c_desc,
8221  rnn_flags flags = rnn_flags::undef) {
8224  dnnl::convert_to_c(aprop_kind),
8225  dnnl::convert_to_c(direction), &src_layer_desc.data,
8226  &src_iter_desc.data, &src_iter_c_desc.data,
8227  &weights_layer_desc.data, &weights_iter_desc.data,
8228  &bias_desc.data, &dst_layer_desc.data,
8229  &dst_iter_desc.data, &dst_iter_c_desc.data,
8230  dnnl::convert_to_c(flags)),
8231  "could not create a descriptor for an LSTM forward "
8232  "propagation primitive");
8233  }
8234  };
8235 
8239  primitive_desc() = default;
8240 
8250  primitive_desc(const desc &adesc, const engine &aengine,
8251  bool allow_empty = false)
8253  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
8254 
8265  primitive_desc(const desc &adesc, const primitive_attr &attr,
8266  const engine &aengine, bool allow_empty = false)
8268  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
8269 
8280 
8283  return rnn_base::src_layer_desc();
8284  }
8285 
8288 
8291  return rnn_base::src_iter_c_desc();
8292  }
8293 
8297  }
8298 
8301  return rnn_base::weights_iter_desc();
8302  }
8303 
8307  }
8308 
8312  }
8313 
8316 
8319  return rnn_base::dst_layer_desc();
8320  }
8321 
8324 
8327  return rnn_base::dst_iter_c_desc();
8328  }
8329 
8332  return rnn_base::workspace_desc();
8333  }
8334  };
8335 
8337  lstm_forward() = default;
8338 
8343 };
8344 
8346 struct lstm_backward : public primitive {
8348  struct desc {
8349  dnnl_rnn_desc_t data;
8350 
8426  desc(prop_kind aprop_kind, rnn_direction direction,
8427  const memory::desc &src_layer_desc,
8428  const memory::desc &src_iter_desc,
8429  const memory::desc &src_iter_c_desc,
8430  const memory::desc &weights_layer_desc,
8431  const memory::desc &weights_iter_desc,
8432  const memory::desc &weights_peephole_desc,
8433  const memory::desc &weights_projection_desc,
8434  const memory::desc &bias_desc,
8435  const memory::desc &dst_layer_desc,
8436  const memory::desc &dst_iter_desc,
8437  const memory::desc &dst_iter_c_desc,
8438  const memory::desc &diff_src_layer_desc,
8439  const memory::desc &diff_src_iter_desc,
8440  const memory::desc &diff_src_iter_c_desc,
8441  const memory::desc &diff_weights_layer_desc,
8442  const memory::desc &diff_weights_iter_desc,
8443  const memory::desc &diff_weights_peephole_desc,
8444  const memory::desc &diff_weights_projection_desc,
8445  const memory::desc &diff_bias_desc,
8446  const memory::desc &diff_dst_layer_desc,
8447  const memory::desc &diff_dst_iter_desc,
8448  const memory::desc &diff_dst_iter_c_desc,
8449  rnn_flags flags = rnn_flags::undef) {
8452  dnnl::convert_to_c(aprop_kind),
8453  dnnl::convert_to_c(direction), &src_layer_desc.data,
8454  &src_iter_desc.data, &src_iter_c_desc.data,
8455  &weights_layer_desc.data, &weights_iter_desc.data,
8456  &weights_peephole_desc.data,
8457  &weights_projection_desc.data, &bias_desc.data,
8458  &dst_layer_desc.data, &dst_iter_desc.data,
8459  &dst_iter_c_desc.data, &diff_src_layer_desc.data,
8460  &diff_src_iter_desc.data,
8461  &diff_src_iter_c_desc.data,
8462  &diff_weights_layer_desc.data,
8463  &diff_weights_iter_desc.data,
8464  &diff_weights_peephole_desc.data,
8465  &diff_weights_projection_desc.data,
8466  &diff_bias_desc.data, &diff_dst_layer_desc.data,
8467  &diff_dst_iter_desc.data,
8468  &diff_dst_iter_c_desc.data,
8469  dnnl::convert_to_c(flags)),
8470  "could not create a descriptor for an LSTM backward "
8471  "propagation primitive");
8472  }
8473 
8538  desc(prop_kind aprop_kind, rnn_direction direction,
8539  const memory::desc &src_layer_desc,
8540  const memory::desc &src_iter_desc,
8541  const memory::desc &src_iter_c_desc,
8542  const memory::desc &weights_layer_desc,
8543  const memory::desc &weights_iter_desc,
8544  const memory::desc &weights_peephole_desc,
8545  const memory::desc &bias_desc,
8546  const memory::desc &dst_layer_desc,
8547  const memory::desc &dst_iter_desc,
8548  const memory::desc &dst_iter_c_desc,
8549  const memory::desc &diff_src_layer_desc,
8550  const memory::desc &diff_src_iter_desc,
8551  const memory::desc &diff_src_iter_c_desc,
8552  const memory::desc &diff_weights_layer_desc,
8553  const memory::desc &diff_weights_iter_desc,
8554  const memory::desc &diff_weights_peephole_desc,
8555  const memory::desc &diff_bias_desc,
8556  const memory::desc &diff_dst_layer_desc,
8557  const memory::desc &diff_dst_iter_desc,
8558  const memory::desc &diff_dst_iter_c_desc,
8559  rnn_flags flags = rnn_flags::undef) {
8562  dnnl::convert_to_c(aprop_kind),
8563  dnnl::convert_to_c(direction), &src_layer_desc.data,
8564  &src_iter_desc.data, &src_iter_c_desc.data,
8565  &weights_layer_desc.data, &weights_iter_desc.data,
8566  &weights_peephole_desc.data, &bias_desc.data,
8567  &dst_layer_desc.data, &dst_iter_desc.data,
8568  &dst_iter_c_desc.data, &diff_src_layer_desc.data,
8569  &diff_src_iter_desc.data,
8570  &diff_src_iter_c_desc.data,
8571  &diff_weights_layer_desc.data,
8572  &diff_weights_iter_desc.data,
8573  &diff_weights_peephole_desc.data,
8574  &diff_bias_desc.data, &diff_dst_layer_desc.data,
8575  &diff_dst_iter_desc.data,
8576  &diff_dst_iter_c_desc.data,
8577  dnnl::convert_to_c(flags)),
8578  "could not create a descriptor for an LSTM backward "
8579  "propagation primitive");
8580  }
8581 
8637  desc(prop_kind aprop_kind, rnn_direction direction,
8638  const memory::desc &src_layer_desc,
8639  const memory::desc &src_iter_desc,
8640  const memory::desc &src_iter_c_desc,
8641  const memory::desc &weights_layer_desc,
8642  const memory::desc &weights_iter_desc,
8643  const memory::desc &bias_desc,
8644  const memory::desc &dst_layer_desc,
8645  const memory::desc &dst_iter_desc,
8646  const memory::desc &dst_iter_c_desc,
8647  const memory::desc &diff_src_layer_desc,
8648  const memory::desc &diff_src_iter_desc,
8649  const memory::desc &diff_src_iter_c_desc,
8650  const memory::desc &diff_weights_layer_desc,
8651  const memory::desc &diff_weights_iter_desc,
8652  const memory::desc &diff_bias_desc,
8653  const memory::desc &diff_dst_layer_desc,
8654  const memory::desc &diff_dst_iter_desc,
8655  const memory::desc &diff_dst_iter_c_desc,
8656  rnn_flags flags = rnn_flags::undef) {
8659  dnnl::convert_to_c(aprop_kind),
8660  dnnl::convert_to_c(direction), &src_layer_desc.data,
8661  &src_iter_desc.data, &src_iter_c_desc.data,
8662  &weights_layer_desc.data, &weights_iter_desc.data,
8663  &bias_desc.data, &dst_layer_desc.data,
8664  &dst_iter_desc.data, &dst_iter_c_desc.data,
8665  &diff_src_layer_desc.data, &diff_src_iter_desc.data,
8666  &diff_src_iter_c_desc.data,
8667  &diff_weights_layer_desc.data,
8668  &diff_weights_iter_desc.data, &diff_bias_desc.data,
8669  &diff_dst_layer_desc.data, &diff_dst_iter_desc.data,
8670  &diff_dst_iter_c_desc.data,
8671  dnnl::convert_to_c(flags)),
8672  "could not create a descriptor for an LSTM backward "
8673  "propagation primitive");
8674  }
8675  };
8676 
8680  primitive_desc() = default;
8681 
8694  primitive_desc(const desc &adesc, const engine &aengine,
8695  const lstm_forward::primitive_desc &hint_fwd_pd,
8696  bool allow_empty = false)
8697  : rnn_primitive_desc_base(&adesc.data, nullptr, aengine,
8698  hint_fwd_pd.get(), allow_empty) {}
8699 
8713  primitive_desc(const desc &adesc, const primitive_attr &attr,
8714  const engine &aengine,
8715  const lstm_forward::primitive_desc &hint_fwd_pd,
8716  bool allow_empty = false)
8717  : rnn_primitive_desc_base(&adesc.data, &attr, aengine,
8718  hint_fwd_pd.get(), allow_empty) {}
8719 
8729 
8732  return rnn_base::src_layer_desc();
8733  }
8734 
8737 
8740  return rnn_base::src_iter_c_desc();
8741  }
8742 
8746  }
8747 
8750  return rnn_base::weights_iter_desc();
8751  }
8752 
8756  }
8757 
8761  }
8762 
8765 
8768  return rnn_base::dst_layer_desc();
8769  }
8770 
8773 
8776  return rnn_base::dst_iter_c_desc();
8777  }
8778 
8781  return rnn_base::workspace_desc();
8782  }
8783 
8787  }
8788 
8792  }
8793 
8797  }
8798 
8802  }
8803 
8807  }
8808 
8812  }
8813 
8817  }
8818 
8821  return rnn_base::diff_bias_desc();
8822  }
8823 
8827  }
8828 
8832  }
8833 
8837  }
8838  };
8839 
8841  lstm_backward() = default;
8842 
8847 };
8848 
8850 struct gru_forward : public primitive {
8852  struct desc {
8853  dnnl_rnn_desc_t data;
8854 
8887  desc(prop_kind aprop_kind, rnn_direction direction,
8888  const memory::desc &src_layer_desc,
8889  const memory::desc &src_iter_desc,
8890  const memory::desc &weights_layer_desc,
8891  const memory::desc &weights_iter_desc,
8892  const memory::desc &bias_desc,
8893  const memory::desc &dst_layer_desc,
8894  const memory::desc &dst_iter_desc,
8895  rnn_flags flags = rnn_flags::undef) {
8898  dnnl::convert_to_c(aprop_kind),
8899  dnnl::convert_to_c(direction), &src_layer_desc.data,
8900  &src_iter_desc.data, &weights_layer_desc.data,
8901  &weights_iter_desc.data, &bias_desc.data,
8902  &dst_layer_desc.data, &dst_iter_desc.data,
8903  dnnl::convert_to_c(flags)),
8904  "could not create a descriptor for a GRU forward "
8905  "propagation primitive");
8906  }
8907  };
8908 
8912  primitive_desc() = default;
8913 
8923  primitive_desc(const desc &adesc, const engine &aengine,
8924  bool allow_empty = false)
8926  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
8927 
8938  primitive_desc(const desc &adesc, const primitive_attr &attr,
8939  const engine &aengine, bool allow_empty = false)
8941  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
8942 
8952  dnnl::algorithm::vanilla_gru) {}
8953 
8956  return rnn_base::src_layer_desc();
8957  }
8958 
8961 
8965  }
8966 
8969  return rnn_base::weights_iter_desc();
8970  }
8971 
8974 
8977  return rnn_base::dst_layer_desc();
8978  }
8979 
8982 
8985  return rnn_base::workspace_desc();
8986  }
8987  };
8988 
8990  gru_forward() = default;
8991 
8996 };
8997 
8999 struct gru_backward : public primitive {
9001  struct desc {
9002  dnnl_rnn_desc_t data;
9003 
9048  desc(prop_kind aprop_kind, rnn_direction direction,
9049  const memory::desc &src_layer_desc,
9050  const memory::desc &src_iter_desc,
9051  const memory::desc &weights_layer_desc,
9052  const memory::desc &weights_iter_desc,
9053  const memory::desc &bias_desc,
9054  const memory::desc &dst_layer_desc,
9055  const memory::desc &dst_iter_desc,
9056  const memory::desc &diff_src_layer_desc,
9057  const memory::desc &diff_src_iter_desc,
9058  const memory::desc &diff_weights_layer_desc,
9059  const memory::desc &diff_weights_iter_desc,
9060  const memory::desc &diff_bias_desc,
9061  const memory::desc &diff_dst_layer_desc,
9062  const memory::desc &diff_dst_iter_desc,
9063  rnn_flags flags = rnn_flags::undef) {
9066  dnnl::convert_to_c(aprop_kind),
9067  dnnl::convert_to_c(direction), &src_layer_desc.data,
9068  &src_iter_desc.data, &weights_layer_desc.data,
9069  &weights_iter_desc.data, &bias_desc.data,
9070  &dst_layer_desc.data, &dst_iter_desc.data,
9071  &diff_src_layer_desc.data, &diff_src_iter_desc.data,
9072  &diff_weights_layer_desc.data,
9073  &diff_weights_iter_desc.data, &diff_bias_desc.data,
9074  &diff_dst_layer_desc.data, &diff_dst_iter_desc.data,
9075  dnnl::convert_to_c(flags)),
9076  "could not create a descriptor for a GRU backward "
9077  "propagation primitive");
9078  }
9079  };
9080 
9084  primitive_desc() = default;
9085 
9098  primitive_desc(const desc &adesc, const engine &aengine,
9099  const gru_forward::primitive_desc &hint_fwd_pd,
9100  bool allow_empty = false)
9101  : rnn_primitive_desc_base(&adesc.data, nullptr, aengine,
9102  hint_fwd_pd.get(), allow_empty) {}
9103 
9117  primitive_desc(const desc &adesc, const primitive_attr &attr,
9118  const engine &aengine,
9119  const gru_forward::primitive_desc &hint_fwd_pd,
9120  bool allow_empty = false)
9121  : rnn_primitive_desc_base(&adesc.data, &attr, aengine,
9122  hint_fwd_pd.get(), allow_empty) {}
9123 
9132  dnnl::algorithm::vanilla_gru) {}
9133 
9136  return rnn_base::src_layer_desc();
9137  }
9138 
9141 
9145  }
9146 
9149  return rnn_base::weights_iter_desc();
9150  }
9151 
9154 
9157  return rnn_base::dst_layer_desc();
9158  }
9159 
9162 
9165  return rnn_base::workspace_desc();
9166  }
9167 
9171  }
9172 
9176  }
9177 
9181  }
9182 
9186  }
9187 
9190  return rnn_base::diff_bias_desc();
9191  }
9192 
9196  }
9197 
9201  }
9202  };
9203 
9205  gru_backward() = default;
9206 
9211 };
9212 
9214 struct lbr_gru_forward : public primitive {
9216  struct desc {
9217  dnnl_rnn_desc_t data;
9218 
9252  desc(prop_kind aprop_kind, rnn_direction direction,
9253  const memory::desc &src_layer_desc,
9254  const memory::desc &src_iter_desc,
9255  const memory::desc &weights_layer_desc,
9256  const memory::desc &weights_iter_desc,
9257  const memory::desc &bias_desc,
9258  const memory::desc &dst_layer_desc,
9259  const memory::desc &dst_iter_desc,
9260  rnn_flags flags = rnn_flags::undef) {
9263  dnnl::convert_to_c(aprop_kind),
9264  dnnl::convert_to_c(direction), &src_layer_desc.data,
9265  &src_iter_desc.data, &weights_layer_desc.data,
9266  &weights_iter_desc.data, &bias_desc.data,
9267  &dst_layer_desc.data, &dst_iter_desc.data,
9268  dnnl::convert_to_c(flags)),
9269  "could not create a descriptor for an LBR GRU forward "
9270  "propagation primitive");
9271  }
9272  };
9273 
9277  primitive_desc() = default;
9278 
9289  primitive_desc(const desc &adesc, const engine &aengine,
9290  bool allow_empty = false)
9292  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
9293 
9305  primitive_desc(const desc &adesc, const primitive_attr &attr,
9306  const engine &aengine, bool allow_empty = false)
9308  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
9309 
9319  dnnl::algorithm::lbr_gru) {}
9320 
9323  return rnn_base::src_layer_desc();
9324  }
9325 
9328 
9332  }
9333 
9336  return rnn_base::weights_iter_desc();
9337  }
9338 
9341 
9344  return rnn_base::dst_layer_desc();
9345  }
9346 
9349 
9352  return rnn_base::workspace_desc();
9353  }
9354  };
9355 
9357  lbr_gru_forward() = default;
9358 
9363 };
9364 
9366 struct lbr_gru_backward : public primitive {
9368  struct desc {
9369  dnnl_rnn_desc_t data;
9370 
9416  desc(prop_kind aprop_kind, rnn_direction direction,
9417  const memory::desc &src_layer_desc,
9418  const memory::desc &src_iter_desc,
9419  const memory::desc &weights_layer_desc,
9420  const memory::desc &weights_iter_desc,
9421  const memory::desc &bias_desc,
9422  const memory::desc &dst_layer_desc,
9423  const memory::desc &dst_iter_desc,
9424  const memory::desc &diff_src_layer_desc,
9425  const memory::desc &diff_src_iter_desc,
9426  const memory::desc &diff_weights_layer_desc,
9427  const memory::desc &diff_weights_iter_desc,
9428  const memory::desc &diff_bias_desc,
9429  const memory::desc &diff_dst_layer_desc,
9430  const memory::desc &diff_dst_iter_desc,
9431  rnn_flags flags = rnn_flags::undef) {
9434  dnnl::convert_to_c(aprop_kind),
9435  dnnl::convert_to_c(direction), &src_layer_desc.data,
9436  &src_iter_desc.data, &weights_layer_desc.data,
9437  &weights_iter_desc.data, &bias_desc.data,
9438  &dst_layer_desc.data, &dst_iter_desc.data,
9439  &diff_src_layer_desc.data, &diff_src_iter_desc.data,
9440  &diff_weights_layer_desc.data,
9441  &diff_weights_iter_desc.data, &diff_bias_desc.data,
9442  &diff_dst_layer_desc.data, &diff_dst_iter_desc.data,
9443  dnnl::convert_to_c(flags)),
9444  "could not create a descriptor for an LBR GRU backward "
9445  "propagation primitive");
9446  }
9447  };
9448 
9452  primitive_desc() = default;
9453 
9467  primitive_desc(const desc &adesc, const engine &aengine,
9468  const lbr_gru_forward::primitive_desc &hint_fwd_pd,
9469  bool allow_empty = false)
9470  : rnn_primitive_desc_base(&adesc.data, nullptr, aengine,
9471  hint_fwd_pd.get(), allow_empty) {}
9472 
9487  primitive_desc(const desc &adesc, const primitive_attr &attr,
9488  const engine &aengine,
9489  const lbr_gru_forward::primitive_desc &hint_fwd_pd,
9490  bool allow_empty = false)
9491  : rnn_primitive_desc_base(&adesc.data, &attr, aengine,
9492  hint_fwd_pd.get(), allow_empty) {}
9493 
9503 
9506  return rnn_base::src_layer_desc();
9507  }
9508 
9511 
9515  }
9516 
9519  return rnn_base::weights_iter_desc();
9520  }
9521 
9524 
9527  return rnn_base::dst_layer_desc();
9528  }
9529 
9532 
9535  return rnn_base::workspace_desc();
9536  }
9537 
9541  }
9542 
9546  }
9547 
9551  }
9552 
9556  }
9557 
9560  return rnn_base::diff_bias_desc();
9561  }
9562 
9566  }
9567 
9571  }
9572  };
9573 
9575  lbr_gru_backward() = default;
9576 
9581 };
9582 
9584 
9592 
9594 struct shuffle_forward : public primitive {
9596  struct desc {
9597  dnnl_shuffle_desc_t data;
9598 
9608  desc(prop_kind aprop_kind, const memory::desc &data_desc, int axis,
9609  int group_size) {
9611  dnnl::convert_to_c(aprop_kind),
9612  &data_desc.data, axis, group_size),
9613  "could not create a descriptor for a shuffle forward "
9614  "propagation primitive");
9615  }
9616  };
9617 
9621  primitive_desc() = default;
9622 
9634  primitive_desc(const desc &adesc, const engine &aengine,
9635  const primitive_attr &attr = primitive_attr(),
9636  bool allow_empty = false)
9637  : dnnl::primitive_desc(
9638  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
9639 
9647  : dnnl::primitive_desc(pd, dnnl::primitive::kind::shuffle,
9650 
9652  memory::desc src_desc() const { return base::src_desc(0); }
9653 
9655  memory::desc dst_desc() const { return base::dst_desc(0); }
9656  };
9657 
9659  shuffle_forward() = default;
9660 
9665 };
9666 
9668 struct shuffle_backward : public primitive {
9671  struct desc {
9672  dnnl_shuffle_desc_t data;
9673 
9681  desc(const memory::desc &diff_data_desc, int axis, int group_size) {
9683  &diff_data_desc.data, axis, group_size),
9684  "could not create a descriptor for a shuffle backward "
9685  "propagation primitive");
9686  }
9687  };
9688 
9692  primitive_desc() = default;
9693 
9708  primitive_desc(const desc &adesc, const engine &aengine,
9709  const shuffle_forward::primitive_desc &hint_fwd_pd,
9710  const primitive_attr &attr = primitive_attr(),
9711  bool allow_empty = false)
9712  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
9713  hint_fwd_pd.get(), allow_empty) {}
9714 
9722  : dnnl::primitive_desc(pd, dnnl::primitive::kind::shuffle,
9724 
9727 
9730  };
9731 
9733  shuffle_backward() = default;
9734 
9739 };
9740 
9742 
9750 
9752 struct binary : public primitive {
9754  struct desc {
9757 
9759  desc() = default;
9760 
9768  desc(algorithm aalgorithm, const memory::desc &src0,
9769  const memory::desc &src1, const memory::desc &dst) {
9772  &src0.data, &src1.data, &dst.data),
9773  "could not create a descriptor for a binary operation "
9774  "primitive");
9775  }
9776  };
9777 
9781  primitive_desc() = default;
9782 
9792  primitive_desc(const desc &adesc, const engine &aengine,
9793  bool allow_empty = false)
9794  : dnnl::primitive_desc(
9795  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
9796 
9807  primitive_desc(const desc &adesc, const primitive_attr &attr,
9808  const engine &aengine, bool allow_empty = false)
9809  : dnnl::primitive_desc(
9810  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
9811 
9818 
9820  memory::desc src_desc(int idx = 0) const { return base::src_desc(idx); }
9821 
9823  memory::desc src0_desc() const { return base::src_desc(0); }
9824 
9826  memory::desc src1_desc() const { return base::src_desc(1); }
9827 
9829  memory::desc dst_desc() const { return base::dst_desc(0); }
9830  };
9831 
9833  binary() = default;
9834 
9838  binary(const primitive_desc &pd) : primitive(pd) {}
9839 };
9840 
9842 
9852 
9854 struct matmul : public primitive {
9856  struct desc {
9857  dnnl_matmul_desc_t data;
9858 
9864  desc(const memory::desc &src_desc, const memory::desc &weights_desc,
9865  const memory::desc &dst_desc) {
9867  dnnl_matmul_desc_init(&data, &src_desc.data,
9868  &weights_desc.data, nullptr, &dst_desc.data),
9869  "could not create a descriptor for a matmul primitive");
9870  }
9871 
9878  desc(const memory::desc &src_desc, const memory::desc &weights_desc,
9879  const memory::desc &bias_desc, const memory::desc &dst_desc) {
9880  error::wrap_c_api(dnnl_matmul_desc_init(&data, &src_desc.data,
9881  &weights_desc.data, &bias_desc.data,
9882  &dst_desc.data),
9883  "could not create a descriptor for a matmul primitive");
9884  }
9885  };
9886 
9890  primitive_desc() = default;
9891 
9900  primitive_desc(const desc &adesc, const engine &aengine,
9901  bool allow_empty = false)
9902  : dnnl::primitive_desc(
9903  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
9904 
9914  primitive_desc(const desc &adesc, const primitive_attr &attr,
9915  const engine &aengine, bool allow_empty = false)
9916  : dnnl::primitive_desc(
9917  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
9918 
9925 
9928 
9931  return query_md(query::weights_md, 0);
9932  }
9933 
9936  return query_md(query::weights_md, 1);
9937  }
9938 
9941  };
9942 
9944  matmul() = default;
9945 
9948  matmul(const primitive_desc &pd) : primitive(pd) {}
9949 };
9950 
9952 
9962 
9966  struct desc {
9968 
9984  desc(prop_kind aprop_kind, algorithm aalgorithm,
9985  const memory::desc &src_desc, const memory::desc &dst_desc) {
9987  dnnl::convert_to_c(aprop_kind),
9988  convert_to_c(aalgorithm), nullptr,
9989  &src_desc.data, &dst_desc.data),
9990  "could not create a resampling forward descriptor");
9991  }
9992 
10004  desc(prop_kind aprop_kind, algorithm aalgorithm,
10005  const std::vector<float> &factors,
10006  const memory::desc &src_desc) {
10007  memory::validate_dims(factors, src_desc.data.ndims - 2);
10009  dnnl::convert_to_c(aprop_kind),
10010  convert_to_c(aalgorithm), &factors[0],
10011  &src_desc.data, nullptr),
10012  "could not create a resampling forward descriptor");
10013  }
10014 
10031  desc(prop_kind aprop_kind, algorithm aalgorithm,
10032  const std::vector<float> &factors, const memory::desc &src_desc,
10033  const memory::desc &dst_desc) {
10034  if (!factors.empty())
10035  memory::validate_dims(factors, src_desc.data.ndims - 2);
10037  dnnl::convert_to_c(aprop_kind),
10038  convert_to_c(aalgorithm), factors.data(),
10039  &src_desc.data, &dst_desc.data),
10040  "could not create a resampling forward descriptor");
10041  }
10042  };
10043 
10047  primitive_desc() = default;
10048 
10059  primitive_desc(const desc &adesc, const engine &aengine,
10060  bool allow_empty = false)
10061  : dnnl::primitive_desc(
10062  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
10063 
10075  primitive_desc(const desc &adesc, const primitive_attr &attr,
10076  const engine &aengine, bool allow_empty = false)
10077  : dnnl::primitive_desc(
10078  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
10079 
10087  : dnnl::primitive_desc(pd, dnnl::primitive::kind::resampling,
10090 
10092  memory::desc src_desc() const { return base::src_desc(0); }
10093 
10095  memory::desc dst_desc() const { return base::dst_desc(0); }
10096  };
10097 
10099  resampling_forward() = default;
10100 
10105 };
10106 
10110  struct desc {
10112 
10121  desc(algorithm aalgorithm, const memory::desc &diff_src_desc,
10122  const memory::desc &diff_dst_desc) {
10124  convert_to_c(aalgorithm), nullptr,
10125  &diff_src_desc.data, &diff_dst_desc.data),
10126  "could not create a resampling backward data descriptor");
10127  }
10128 
10138  desc(algorithm aalgorithm, const std::vector<float> &factors,
10139  const memory::desc &diff_src_desc,
10140  const memory::desc &diff_dst_desc) {
10141  if (!factors.empty())
10142  memory::validate_dims(factors, diff_src_desc.data.ndims - 2);
10144  convert_to_c(aalgorithm), factors.data(),
10145  &diff_src_desc.data, &diff_dst_desc.data),
10146  "could not create a resampling backward data descriptor");
10147  }
10148  };
10149 
10153  primitive_desc() = default;
10154 
10168  primitive_desc(const desc &adesc, const engine &aengine,
10169  const resampling_forward::primitive_desc &hint_fwd_pd,
10170  bool allow_empty = false)
10171  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
10172  hint_fwd_pd.get(), allow_empty) {}
10173 
10188  primitive_desc(const desc &adesc, const primitive_attr &attr,
10189  const engine &aengine,
10190  const resampling_forward::primitive_desc &hint_fwd_pd,
10191  bool allow_empty = false)
10192  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
10193  hint_fwd_pd.get(), allow_empty) {}
10194 
10202  : dnnl::primitive_desc(pd, dnnl::primitive::kind::resampling,
10204 
10207 
10210  };
10211 
10213  resampling_backward() = default;
10214 
10219 };
10220 
10222 
10230 
10234  struct desc {
10236 
10263  desc(prop_kind aprop_kind, algorithm aalgorithm,
10264  const memory::desc &src_desc, const memory::desc &dst_desc,
10265  const memory::dims &strides, const memory::dims &kernel,
10266  const memory::dims &dilation, const memory::dims &padding_l,
10267  const memory::dims &padding_r) {
10268  memory::validate_dims(strides, src_desc.data.ndims - 2);
10269  memory::validate_dims(kernel, src_desc.data.ndims - 2);
10270  memory::validate_dims(padding_l, src_desc.data.ndims - 2);
10271  memory::validate_dims(padding_r, src_desc.data.ndims - 2);
10272  memory::validate_dims(dilation, src_desc.data.ndims - 2);
10275  dnnl::convert_to_c(aprop_kind),
10276  convert_to_c(aalgorithm), &src_desc.data,
10277  &dst_desc.data, &strides[0], &kernel[0],
10278  &dilation[0], &padding_l[0], &padding_r[0]),
10279  "could not create a descriptor for a pooling forward "
10280  "propagation primitive");
10281  }
10282  };
10283 
10287  primitive_desc() = default;
10288 
10299  primitive_desc(const desc &adesc, const engine &aengine,
10300  bool allow_empty = false)
10301  : dnnl::primitive_desc(
10302  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
10303 
10315  primitive_desc(const desc &adesc, const primitive_attr &attr,
10316  const engine &aengine, bool allow_empty = false)
10317  : dnnl::primitive_desc(
10318  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
10319 
10328  : dnnl::primitive_desc(pd, dnnl::primitive::kind::pooling_v2,
10331 
10333  memory::desc src_desc() const { return base::src_desc(0); }
10334 
10336  memory::desc dst_desc() const { return base::dst_desc(0); }
10337 
10340  };
10341 
10343  pooling_v2_forward() = default;
10344 
10350 };
10351 
10355  struct desc {
10357 
10381  desc(algorithm aalgorithm, const memory::desc &diff_src_desc,
10382  const memory::desc &diff_dst_desc, const memory::dims &strides,
10383  const memory::dims &kernel, const memory::dims &dilation,
10384  const memory::dims &padding_l, const memory::dims &padding_r) {
10385  memory::validate_dims(strides, diff_src_desc.data.ndims - 2);
10386  memory::validate_dims(kernel, diff_src_desc.data.ndims - 2);
10387  memory::validate_dims(padding_l, diff_src_desc.data.ndims - 2);
10388  memory::validate_dims(padding_r, diff_src_desc.data.ndims - 2);
10389  memory::validate_dims(dilation, diff_src_desc.data.ndims - 2);
10392  convert_to_c(aalgorithm), &diff_src_desc.data,
10393  &diff_dst_desc.data, &strides[0], &kernel[0],
10394  &dilation[0], &padding_l[0], &padding_r[0]),
10395  "could not create a descriptor for a pooling backward "
10396  "propagation primitive");
10397  }
10398  };
10399 
10404  primitive_desc() = default;
10405 
10419  primitive_desc(const desc &adesc, const engine &aengine,
10420  const pooling_v2_forward::primitive_desc &hint_fwd_pd,
10421  bool allow_empty = false)
10422  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
10423  hint_fwd_pd.get(), allow_empty) {}
10424 
10439  primitive_desc(const desc &adesc, const primitive_attr &attr,
10440  const engine &aengine,
10441  const pooling_v2_forward::primitive_desc &hint_fwd_pd,
10442  bool allow_empty = false)
10443  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
10444  hint_fwd_pd.get(), allow_empty) {}
10445 
10454  : dnnl::primitive_desc(pd, dnnl::primitive::kind::pooling_v2,
10456 
10459 
10462 
10465  };
10466 
10468  pooling_v2_backward() = default;
10469 
10475 };
10476 
10478 
10487 
10489 struct prelu_forward : public primitive {
10491  struct desc {
10492  dnnl_prelu_desc_t data;
10493 
10502  desc(prop_kind aprop_kind, const memory::desc &data_desc,
10503  const memory::desc &weight_desc) {
10505  dnnl::convert_to_c(aprop_kind),
10506  &data_desc.data, &weight_desc.data),
10507  "could not create a descriptor for a prelu forward "
10508  "propagation primitive");
10509  }
10510  };
10511 
10515  primitive_desc() = default;
10516 
10527  primitive_desc(const desc &adesc, const engine &aengine,
10528  bool allow_empty = false)
10529  : dnnl::primitive_desc(
10530  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
10531 
10543  primitive_desc(const desc &adesc, const primitive_attr &attr,
10544  const engine &aengine, bool allow_empty = false)
10545  : dnnl::primitive_desc(
10546  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
10547 
10555  : dnnl::primitive_desc(pd, dnnl::primitive::kind::prelu,
10558 
10560  memory::desc src_desc() const { return base::src_desc(0); }
10561 
10563  memory::desc dst_desc() const { return base::dst_desc(0); }
10564  };
10565 
10567  prelu_forward() = default;
10568 
10573 };
10574 
10576 struct prelu_backward : public primitive {
10578  struct desc {
10579  dnnl_prelu_desc_t data;
10580 
10589  desc(const memory::desc &data_desc, const memory::desc &weight_desc,
10590  const memory::desc &diff_data_desc,
10591  const memory::desc &diff_weights_desc) {
10593  dnnl_prelu_backward_desc_init(&data, &data_desc.data,
10594  &weight_desc.data, &diff_data_desc.data,
10595  &diff_weights_desc.data),
10596  "could not create a descriptor for a prelu backward "
10597  "propagation primitive");
10598  }
10599  };
10600 
10604  primitive_desc() = default;
10605 
10619  primitive_desc(const desc &adesc, const engine &aengine,
10620  const prelu_forward::primitive_desc &hint_fwd_pd,
10621  bool allow_empty = false)
10622  : dnnl::primitive_desc(&adesc.data, nullptr, aengine,
10623  hint_fwd_pd.get(), allow_empty) {}
10624 
10639  primitive_desc(const desc &adesc, const primitive_attr &attr,
10640  const engine &aengine,
10641  const prelu_forward::primitive_desc &hint_fwd_pd,
10642  bool allow_empty = false)
10643  : dnnl::primitive_desc(&adesc.data, &attr, aengine,
10644  hint_fwd_pd.get(), allow_empty) {}
10645 
10653  : dnnl::primitive_desc(pd, dnnl::primitive::kind::prelu,
10655 
10657  memory::desc src_desc() const { return base::src_desc(0); }
10658 
10661 
10664  };
10665 
10667  prelu_backward() = default;
10668 
10673 };
10674 
10676 
10685 
10687 struct reduction : public primitive {
10689  struct desc {
10690  dnnl_reduction_desc_t data;
10691 
10693  desc() = default;
10694 
10712  desc(algorithm aalgorithm, const memory::desc &src_desc,
10713  const memory::desc &dst_desc, float p, float eps) {
10715  dnnl_reduction_desc_init(&data, convert_to_c(aalgorithm),
10716  &src_desc.data, &dst_desc.data, p, eps),
10717  "could not create a reduction descriptor");
10718  }
10719  };
10720 
10724  primitive_desc() = default;
10725 
10734  primitive_desc(const desc &adesc, const engine &aengine,
10735  bool allow_empty = false)
10736  : dnnl::primitive_desc(
10737  &adesc.data, nullptr, aengine, nullptr, allow_empty) {}
10738 
10748  primitive_desc(const desc &adesc, const primitive_attr &attr,
10749  const engine &aengine, bool allow_empty = false)
10750  : dnnl::primitive_desc(
10751  &adesc.data, &attr, aengine, nullptr, allow_empty) {}
10752 
10759 
10761  memory::desc src_desc() const { return base::src_desc(0); }
10762 
10764  memory::desc dst_desc() const { return base::dst_desc(0); }
10765  };
10766 
10768  reduction() = default;
10769 
10772  reduction(const primitive_desc &pd) : primitive(pd) {}
10773 };
10774 
10776 
10778 
10784 
10787 
10789 enum class status {
10804 };
10805 
10807 inline status set_verbose(int level) {
10808  return static_cast<status>(dnnl_set_verbose(level));
10809 }
10810 
10812 inline const version_t *version() {
10813  return dnnl_version();
10814 }
10815 
10817 inline status set_jit_dump(int enable) {
10818  return static_cast<status>(dnnl_set_jit_dump(enable));
10819 }
10820 
10822 inline status set_jit_profiling_flags(unsigned flags) {
10823  return static_cast<status>(dnnl_set_jit_profiling_flags(flags));
10824 }
10825 
10827 inline status set_jit_profiling_jitdumpdir(const std::string &dir) {
10828  return static_cast<status>(dnnl_set_jit_profiling_jitdumpdir(dir.c_str()));
10829 }
10830 
10832 enum class cpu_isa {
10855 };
10856 
10859  return static_cast<status>(
10860  dnnl_set_max_cpu_isa(static_cast<dnnl_cpu_isa_t>(isa)));
10861 }
10862 
10865  return static_cast<cpu_isa>(dnnl_get_effective_cpu_isa());
10866 }
10867 
10869 enum class cpu_isa_hints {
10874 };
10875 
10878  return static_cast<status>(dnnl_set_cpu_isa_hints(
10879  static_cast<dnnl_cpu_isa_hints_t>(isa_hints)));
10880 }
10881 
10884  return static_cast<cpu_isa_hints>(dnnl_get_cpu_isa_hints());
10885 }
10886 
10888 
10894 
10898  int result = 0;
10900  "could not get primitive cache capacity");
10901  return result;
10902 }
10903 
10905 inline void set_primitive_cache_capacity(int capacity) {
10907  "could not set primitive cache capacity");
10908 }
10909 
10911 
10918 
10920 inline status sgemm(char transa, char transb, dnnl_dim_t M, dnnl_dim_t N,
10921  dnnl_dim_t K, float alpha, const float *A, dnnl_dim_t lda,
10922  const float *B, dnnl_dim_t ldb, float beta, float *C, dnnl_dim_t ldc) {
10923  return static_cast<status>(dnnl_sgemm(
10924  transa, transb, M, N, K, alpha, A, lda, B, ldb, beta, C, ldc));
10925 }
10926 
10928 inline status gemm_u8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M,
10929  dnnl_dim_t N, dnnl_dim_t K, float alpha, const uint8_t *A,
10930  dnnl_dim_t lda, uint8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
10931  float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co) {
10932  return static_cast<status>(dnnl_gemm_u8s8s32(transa, transb, offsetc, M, N,
10933  K, alpha, A, lda, ao, B, ldb, bo, beta, C, ldc, co));
10934 }
10935 
10937 inline status gemm_s8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M,
10938  dnnl_dim_t N, dnnl_dim_t K, float alpha, const int8_t *A,
10939  dnnl_dim_t lda, int8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo,
10940  float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co) {
10941  return static_cast<status>(dnnl_gemm_s8s8s32(transa, transb, offsetc, M, N,
10942  K, alpha, A, lda, ao, B, ldb, bo, beta, C, ldc, co));
10943 }
10944 
10946 
10947 // implementation section
10948 
10951  dnnl_primitive_t result;
10953  "could not create a primitive");
10954  reset(result);
10955 }
10956 
10957 inline primitive::primitive(const primitive_desc &pd) : primitive(pd.get()) {}
10958 
10959 inline void primitive::execute(const stream &astream,
10960  const std::unordered_map<int, memory> &args) const {
10961  std::vector<dnnl_exec_arg_t> c_args;
10962  c_args.reserve(args.size());
10963  for (const auto &a : args)
10964  c_args.push_back({a.first, a.second.get(true)});
10965 
10966  error::wrap_c_api(dnnl_primitive_execute(get(), astream.get(),
10967  (int)c_args.size(), c_args.data()),
10968  "could not execute a primitive");
10969 }
10970 
10972 
10973 #undef DNNL_DEFINE_BITMASK_OPS
10974 
10975 } // namespace dnnl
10976 
10978 
10981 namespace oneapi {
10982 // Note: without this guard, doxygen warns of potentially recursive namespace
10983 #ifndef DOXYGEN_SHOULD_SKIP_THIS
10985 namespace dnnl = ::dnnl;
10986 #endif
10987 } // namespace oneapi
10988 
10990 
10991 #endif /* ONEAPI_DNNL_DNNL_HPP */
algorithm
Kinds of algorithms.
Definition: dnnl.hpp:470
dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_data_qparams(dnnl_primitive_attr_t attr, const float scale, const float shift)
Set quantization scale and shift parameters for RNN data tensors.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_sum_v2(const_dnnl_post_ops_t post_ops, int index, float *scale, dnnl_data_type_t *data_type)
Returns the parameters of an accumulation (sum) post-op with a data type parameter.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_dw_k3s2p1(const_dnnl_post_ops_t post_ops, int index, dnnl_data_type_t *weights_data_type, dnnl_data_type_t *bias_data_type, dnnl_data_type_t *dst_data_type, dnnl_dim_t *count, int *mask, const float **scales)
Returns the parameters of a depthwise post-op with stride 2.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_scratchpad_mode(dnnl_primitive_attr_t attr, dnnl_scratchpad_mode_t mode)
Sets primitive attributes scratchpad mode.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_post_ops(const_dnnl_primitive_attr_t attr, const_dnnl_post_ops_t *post_ops)
Returns primitive attributes post-ops.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_rnn_weights_qparams(const_dnnl_primitive_attr_t attr, dnnl_dim_t *count, int *mask, const float **scales)
Returns the quantization scaling factors for RNN weights tensors.
dnnl_status_t DNNL_API dnnl_post_ops_append_dw_k3s1p1(dnnl_post_ops_t post_ops, dnnl_data_type_t weights_data_type, dnnl_data_type_t bias_data_type, dnnl_data_type_t dst_data_type, dnnl_dim_t count, int mask, const float *scales)
Appends a depthwise post-op convolution with stride 1.
dnnl_status_t DNNL_API dnnl_post_ops_destroy(dnnl_post_ops_t post_ops)
Destroys post-ops.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_zero_points(dnnl_primitive_attr_t attr, int arg, dnnl_dim_t count, int mask, const int32_t *zero_points)
Sets primitive attributes zero points for primitive operations for a given memory argument.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_post_ops(dnnl_primitive_attr_t attr, const_dnnl_post_ops_t post_ops)
Sets primitive attributes post-ops.
dnnl_status_t DNNL_API dnnl_post_ops_append_sum(dnnl_post_ops_t post_ops, float scale)
Appends an accumulation (sum) to post-ops.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_weights_qparams(dnnl_primitive_attr_t attr, dnnl_dim_t count, int mask, const float *scales)
Sets quantization scaling factors for RNN weights tensors.
dnnl_status_t DNNL_API dnnl_primitive_attr_destroy(dnnl_primitive_attr_t attr)
Destroys primitive attributes.
dnnl_status_t DNNL_API dnnl_post_ops_append_sum_v2(dnnl_post_ops_t post_ops, float scale, dnnl_data_type_t data_type)
Appends an accumulation (sum) post-op with a data type parameter to post-ops.
int DNNL_API dnnl_post_ops_len(const_dnnl_post_ops_t post_ops)
Returns the length of post-ops.
dnnl_status_t DNNL_API dnnl_post_ops_append_binary(dnnl_post_ops_t post_ops, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src1_desc)
Appends a binary post-op.
dnnl_status_t DNNL_API dnnl_post_ops_append_dw_k3s2p1(dnnl_post_ops_t post_ops, dnnl_data_type_t weights_data_type, dnnl_data_type_t bias_data_type, dnnl_data_type_t dst_data_type, dnnl_dim_t count, int mask, const float *scales)
Appends a depthwise post-op convolution with stride 2.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_rnn_weights_projection_qparams(const_dnnl_primitive_attr_t attr, dnnl_dim_t *count, int *mask, const float **scales)
Returns the quantization scaling factors for RNN projection weights tensors.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_eltwise(const_dnnl_post_ops_t post_ops, int index, float *scale, dnnl_alg_kind_t *alg_kind, float *alpha, float *beta)
Returns the parameters of an elementwise post-op.
dnnl_status_t DNNL_API dnnl_post_ops_create(dnnl_post_ops_t *post_ops)
Creates empty post-ops sequence.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_scales(dnnl_primitive_attr_t attr, int arg, dnnl_dim_t count, int mask, const float *scales)
Sets primitive attributes scaling factors for primitive operations for a given memory argument.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_scratchpad_mode(const_dnnl_primitive_attr_t attr, dnnl_scratchpad_mode_t *mode)
Returns the primitive attributes scratchpad mode.
dnnl_status_t DNNL_API dnnl_primitive_attr_clone(dnnl_primitive_attr_t *attr, const_dnnl_primitive_attr_t existing_attr)
Clones primitive attributes.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_dw_k3s1p1(const_dnnl_post_ops_t post_ops, int index, dnnl_data_type_t *weights_data_type, dnnl_data_type_t *bias_data_type, dnnl_data_type_t *dst_data_type, dnnl_dim_t *count, int *mask, const float **scales)
Returns the parameters of a depthwise post-op with stride 1.
dnnl_primitive_kind_t DNNL_API dnnl_post_ops_get_kind(const_dnnl_post_ops_t post_ops, int index)
Returns the kind of a post-op entry.
scratchpad_mode
Scratchpad mode.
Definition: dnnl.hpp:401
dnnl_status_t DNNL_API dnnl_primitive_attr_set_rnn_weights_projection_qparams(dnnl_primitive_attr_t attr, dnnl_dim_t count, int mask, const float *scales)
Sets quantization scaling factors for RNN projection weights tensors.
prop_kind
Propagation kind.
Definition: dnnl.hpp:435
dnnl_scratchpad_mode_t
Scratchpad mode.
Definition: dnnl_types.h:2201
dnnl_status_t DNNL_API dnnl_post_ops_append_eltwise(dnnl_post_ops_t post_ops, float scale, dnnl_alg_kind_t alg_kind, float alpha, float beta)
Appends an elementwise post-op.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_zero_points(const_dnnl_primitive_attr_t attr, int arg, dnnl_dim_t *count, int *mask, const int32_t **zero_points)
Returns count, correspondence zero point mask, and a pointer to a constant int32_t array of zero_poin...
dnnl_status_t DNNL_API dnnl_post_ops_get_params_sum(const_dnnl_post_ops_t post_ops, int index, float *scale)
Returns the parameters of an accumulation (sum) post-op.
dnnl_status_t DNNL_API dnnl_post_ops_get_params_binary(const_dnnl_post_ops_t post_ops, int index, dnnl_alg_kind_t *alg_kind, const dnnl_memory_desc_t **src1_desc)
Returns the parameters of a binary post-op.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_rnn_data_qparams(const_dnnl_primitive_attr_t attr, float *scale, float *shift)
Returns the quantization scale and shift parameters for RNN data tensors.
dnnl_status_t DNNL_API dnnl_primitive_attr_set_output_scales(dnnl_primitive_attr_t attr, dnnl_dim_t count, int mask, const float *scales)
Sets output scaling factors correspondence mask and values.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_scales(dnnl_primitive_attr_t attr, int arg, dnnl_dim_t *count, int *mask, const float **scales)
Returns primitive attributes scaling factors correspondence mask and values for a given memory argume...
dnnl_status_t DNNL_API dnnl_primitive_attr_create(dnnl_primitive_attr_t *attr)
Creates an empty (default) primitive attributes with all the parameters set to their default values.
dnnl_status_t DNNL_API dnnl_primitive_attr_get_output_scales(const_dnnl_primitive_attr_t attr, dnnl_dim_t *count, int *mask, const float **scales)
Returns primitive attributes output scaling factors correspondence mask and values.
@ resampling_linear
Linear (Bilinear, Trilinear) resampling method.
@ binary_mul
Binary mul.
@ resampling_nearest
Nearest Neighbor resampling method.
@ eltwise_elu_use_dst_for_bwd
Elementwise: exponential linear unit (ELU) (dst for backward)
@ eltwise_tanh_use_dst_for_bwd
Elementwise: hyperbolic tangent non-linearity (tanh) (dst for backward)
@ reduction_norm_lp_power_p_sum
Reduction using norm_lp_power_p_sum operation.
@ eltwise_linear
Elementwise: linear.
@ eltwise_clip_v2
Elementwise: clip version 2.
@ eltwise_soft_relu
Elementwise: soft_relu.
@ vanilla_gru
GRU cell.
@ eltwise_logistic
Elementwise: logistic.
@ binary_div
Binary div.
@ eltwise_clip
Elementwise: clip.
@ eltwise_abs
Elementwise: abs.
@ eltwise_pow
Elementwise: pow.
@ eltwise_tanh
Elementwise: hyperbolic tangent non-linearity (tanh)
@ eltwise_logistic_use_dst_for_bwd
Elementwise: logistic (dst for backward)
@ eltwise_bounded_relu
Elementwise: bounded_relu.
@ reduction_norm_lp_power_p_max
Reduction using norm_lp_power_p_max operation.
@ reduction_max
Reduction using max operation.
@ eltwise_clip_v2_use_dst_for_bwd
Elementwise: clip version 2 (dst for backward)
@ eltwise_square
Elementwise: square.
@ binary_max
Binary max.
@ convolution_direct
Direct convolution.
@ eltwise_exp
Elementwise: exponent.
@ reduction_norm_lp_max
Reduction using norm_lp_max operation.
@ eltwise_elu
Elementwise: exponential linear unit (ELU)
@ convolution_winograd
Winograd convolution.
@ vanilla_lstm
LSTM cell.
@ deconvolution_direct
Direct deconvolution.
@ pooling_avg
Average pooling exclude padding, alias for dnnl::algorithm::pooling_avg_exclude_padding.
@ lbr_gru
GRU cell with linear before reset.
@ pooling_avg_exclude_padding
Average pooling exclude padding.
@ eltwise_gelu
Elementwise: gelu, alias for dnnl::algorithm::eltwise_gelu_tanh.
@ eltwise_sqrt
Elementwise: square root.
@ pooling_max
Max pooling.
@ reduction_min
Reduction using min operation.
@ eltwise_gelu_erf
Elementwise: erf-based gelu.
@ eltwise_swish
Elementwise: swish ($x \cdot \mathrm{sigmoid}(\alpha \cdot x)$)
@ binary_sub
Binary sub.
@ lrn_within_channel
LRN within a single channel.
@ reduction_mul
Reduction using mul operation.
@ vanilla_rnn
RNN cell.
@ binary_add
Binary add.
@ lrn_across_channels
Local response normalization (LRN) across multiple channels.
@ eltwise_relu
Elementwise: rectified linear unit (ReLU)
@ eltwise_gelu_tanh
Elementwise: tanh-based gelu.
@ eltwise_relu_use_dst_for_bwd
Elementwise: rectified linear unit (ReLU) (dst for backward)
@ eltwise_logsigmoid
Elementwise: logsigmoid.
@ convolution_auto
Convolution algorithm that is chosen to be either direct or Winograd automatically.
@ binary_min
Binary min.
@ eltwise_exp_use_dst_for_bwd
Elementwise: exponent (dst for backward)
@ eltwise_round
Elementwise: round.
@ eltwise_sqrt_use_dst_for_bwd
Elementwise: square root (dst for backward)
@ pooling_avg_include_padding
Average pooling include padding.
@ reduction_norm_lp_sum
Reduction using norm_lp_sum operation.
@ reduction_mean
Reduction using mean operation.
@ deconvolution_winograd
Winograd deconvolution.
@ eltwise_log
Elementwise: natural logarithm.
@ undef
Undefined algorithm.
@ reduction_sum
Reduction using sum operation.
@ library
The library manages the scratchpad allocation according to the policy specified by the DNNL_ENABLE_CO...
@ user
The user manages the scratchpad allocation by querying and providing the scratchpad memory to primiti...
@ backward
Backward propagation (with respect to all parameters).
@ backward_weights
Backward weights propagation.
@ forward_training
Forward data propagation (training mode).
@ forward_inference
Forward data propagation (inference mode).
@ forward_scoring
Forward data propagation, alias for dnnl::prop_kind::forward_inference.
@ forward
Forward data propagation, alias for dnnl::prop_kind::forward_training.
@ backward_data
Backward data propagation.
@ backward_bias
Backward bias propagation.
@ undef
Undefined propagation kind.
@ dnnl_scratchpad_mode_user
The user manages the scratchpad allocation by querying and providing the scratchpad memory to primitives.
Definition: dnnl_types.h:2223
@ dnnl_scratchpad_mode_library
The library manages the scratchpad allocation according to the policy specified by the DNNL_ENABLE_CONCURRENT_EXEC build option (default).
Definition: dnnl_types.h:2218
dnnl_status_t DNNL_API dnnl_batch_normalization_backward_desc_init(dnnl_batch_normalization_desc_t *bnrm_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *diff_data_desc, const dnnl_memory_desc_t *data_desc, float epsilon, unsigned flags)
Initializes a descriptor for a batch normalization backward propagation primitive.
dnnl_status_t DNNL_API dnnl_batch_normalization_forward_desc_init(dnnl_batch_normalization_desc_t *bnrm_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *data_desc, float epsilon, unsigned flags)
Initializes a descriptor for a batch normalization forward propagation primitive.
dnnl_status_t DNNL_API dnnl_binary_desc_init(dnnl_binary_desc_t *binary_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src0_desc, const dnnl_memory_desc_t *src1_desc, const dnnl_memory_desc_t *dst_desc)
Initializes a descriptor for a binary primitive.
dnnl_status_t DNNL_API dnnl_gemm_s8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const int8_t *A, dnnl_dim_t lda, int8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo, float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co)
Performs integer matrix-matrix multiply on 8-bit signed matrix A, 8-bit signed matrix B, and 32-bit signed resulting matrix C.
status gemm_u8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const uint8_t *A, dnnl_dim_t lda, uint8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo, float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co)
Performs integer matrix-matrix multiply on 8-bit unsigned matrix A, 8-bit signed matrix B, and 32-bit signed resulting matrix C.
Definition: dnnl.hpp:10928
status gemm_s8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const int8_t *A, dnnl_dim_t lda, int8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo, float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co)
Performs integer matrix-matrix multiply on 8-bit signed matrix A, 8-bit signed matrix B, and 32-bit signed resulting matrix C.
Definition: dnnl.hpp:10937
dnnl_status_t DNNL_API dnnl_sgemm(char transa, char transb, dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const float *A, dnnl_dim_t lda, const float *B, dnnl_dim_t ldb, float beta, float *C, dnnl_dim_t ldc)
Performs single-precision matrix-matrix multiply.
status sgemm(char transa, char transb, dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const float *A, dnnl_dim_t lda, const float *B, dnnl_dim_t ldb, float beta, float *C, dnnl_dim_t ldc)
Performs single-precision matrix-matrix multiply.
Definition: dnnl.hpp:10920
dnnl_status_t DNNL_API dnnl_gemm_u8s8s32(char transa, char transb, char offsetc, dnnl_dim_t M, dnnl_dim_t N, dnnl_dim_t K, float alpha, const uint8_t *A, dnnl_dim_t lda, uint8_t ao, const int8_t *B, dnnl_dim_t ldb, int8_t bo, float beta, int32_t *C, dnnl_dim_t ldc, const int32_t *co)
Performs integer matrix-matrix multiply on 8-bit unsigned matrix A, 8-bit signed matrix B, and 32-bit signed resulting matrix C.
dnnl_status_t DNNL_API dnnl_concat_primitive_desc_create(dnnl_primitive_desc_t *concat_primitive_desc, const dnnl_memory_desc_t *dst_desc, int n, int concat_dimension, const dnnl_memory_desc_t *src_descs, const_dnnl_primitive_attr_t attr, dnnl_engine_t engine)
Creates a primitive descriptor for an out-of-place concatenation primitive.
dnnl_status_t DNNL_API dnnl_convolution_forward_desc_init(dnnl_convolution_desc_t *conv_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc, const dnnl_dims_t strides, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a convolution forward propagation primitive.
dnnl_status_t DNNL_API dnnl_dilated_convolution_backward_weights_desc_init(dnnl_convolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *diff_weights_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t dilates, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a dilated convolution weights gradient primitive.
dnnl_status_t DNNL_API dnnl_dilated_convolution_forward_desc_init(dnnl_convolution_desc_t *conv_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc, const dnnl_dims_t strides, const dnnl_dims_t dilates, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a dilated convolution forward propagation primitive.
dnnl_status_t DNNL_API dnnl_convolution_backward_weights_desc_init(dnnl_convolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *diff_weights_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a convolution weights gradient primitive.
dnnl_status_t DNNL_API dnnl_dilated_convolution_backward_data_desc_init(dnnl_convolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t dilates, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a dilated convolution backward propagation primitive.
dnnl_status_t DNNL_API dnnl_convolution_backward_data_desc_init(dnnl_convolution_desc_t *conv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a convolution backward propagation primitive.
dnnl_status_t DNNL_API dnnl_dilated_deconvolution_backward_data_desc_init(dnnl_deconvolution_desc_t *deconv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t dilates, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a dilated deconvolution backward propagation primitive.
dnnl_status_t DNNL_API dnnl_deconvolution_forward_desc_init(dnnl_deconvolution_desc_t *deconv_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc, const dnnl_dims_t strides, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a deconvolution forward propagation primitive.
dnnl_status_t DNNL_API dnnl_deconvolution_backward_weights_desc_init(dnnl_deconvolution_desc_t *deconv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *diff_weights_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a deconvolution weights gradient primitive.
dnnl_status_t DNNL_API dnnl_deconvolution_backward_data_desc_init(dnnl_deconvolution_desc_t *deconv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a deconvolution backward propagation primitive.
dnnl_status_t DNNL_API dnnl_dilated_deconvolution_forward_desc_init(dnnl_deconvolution_desc_t *deconv_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc, const dnnl_dims_t strides, const dnnl_dims_t dilates, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a dilated deconvolution forward propagation primitive.
dnnl_status_t DNNL_API dnnl_dilated_deconvolution_backward_weights_desc_init(dnnl_deconvolution_desc_t *deconv_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *diff_weights_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t dilates, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for a dilated deconvolution weights gradient primitive.
dnnl_status_t DNNL_API dnnl_eltwise_forward_desc_init(dnnl_eltwise_desc_t *eltwise_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *data_desc, float alpha, float beta)
Initializes a descriptor for eltwise forward propagation primitive.
dnnl_status_t DNNL_API dnnl_eltwise_backward_desc_init(dnnl_eltwise_desc_t *eltwise_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_data_desc, const dnnl_memory_desc_t *data_desc, float alpha, float beta)
Initializes a descriptor for eltwise backward propagation primitive.
dnnl_engine_kind_t
Kinds of engines.
Definition: dnnl_types.h:2147
dnnl_status_t DNNL_API dnnl_engine_get_kind(dnnl_engine_t engine, dnnl_engine_kind_t *kind)
Returns the kind of an engine.
dnnl_status_t DNNL_API dnnl_engine_destroy(dnnl_engine_t engine)
Destroys an engine.
dnnl_status_t DNNL_API dnnl_engine_create(dnnl_engine_t *engine, dnnl_engine_kind_t kind, size_t index)
Creates an engine.
size_t DNNL_API dnnl_engine_get_count(dnnl_engine_kind_t kind)
Returns the number of engines of a particular kind.
dnnl_engine_kind_t convert_to_c(engine::kind akind)
Converts engine kind enum value from C++ API to C API type.
Definition: dnnl.hpp:961
@ dnnl_gpu
GPU engine.
Definition: dnnl_types.h:2153
@ dnnl_cpu
CPU engine.
Definition: dnnl_types.h:2151
@ dnnl_any_engine
An unspecified engine.
Definition: dnnl_types.h:2149
dnnl_status_t DNNL_API dnnl_inner_product_forward_desc_init(dnnl_inner_product_desc_t *ip_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc)
Initializes descriptor for inner product forward propagation.
dnnl_status_t DNNL_API dnnl_inner_product_backward_weights_desc_init(dnnl_inner_product_desc_t *ip_desc, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *diff_weights_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_desc)
Initializes descriptor for inner product weights gradient primitive.
dnnl_status_t DNNL_API dnnl_inner_product_backward_data_desc_init(dnnl_inner_product_desc_t *ip_desc, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *diff_dst_desc)
Initializes descriptor for inner product backward propagation.
dnnl_status_t DNNL_API dnnl_layer_normalization_backward_desc_init(dnnl_layer_normalization_desc_t *lnrm_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *diff_data_desc, const dnnl_memory_desc_t *data_desc, const dnnl_memory_desc_t *stat_desc, float epsilon, unsigned flags)
Initializes a descriptor for a layer normalization backward propagation primitive.
dnnl_status_t DNNL_API dnnl_layer_normalization_forward_desc_init(dnnl_layer_normalization_desc_t *lnrm_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *data_desc, const dnnl_memory_desc_t *stat_desc, float epsilon, unsigned flags)
Initializes a descriptor for layer normalization forward propagation primitive.
dnnl_status_t DNNL_API dnnl_logsoftmax_forward_desc_init(dnnl_logsoftmax_desc_t *logsoftmax_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *data_desc, int logsoftmax_axis)
Initializes a descriptor for logsoftmax forward propagation primitive.
dnnl_status_t DNNL_API dnnl_logsoftmax_backward_desc_init(dnnl_logsoftmax_desc_t *logsoftmax_desc, const dnnl_memory_desc_t *diff_data_desc, const dnnl_memory_desc_t *data_desc, int logsoftmax_axis)
Initializes a descriptor for logsoftmax backward propagation primitive.
dnnl_status_t DNNL_API dnnl_lrn_backward_desc_init(dnnl_lrn_desc_t *lrn_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_data_desc, const dnnl_memory_desc_t *data_desc, dnnl_dim_t local_size, float alpha, float beta, float k)
Initializes a descriptor for LRN backward propagation primitive.
dnnl_status_t DNNL_API dnnl_lrn_forward_desc_init(dnnl_lrn_desc_t *lrn_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *data_desc, dnnl_dim_t local_size, float alpha, float beta, float k)
Initializes a descriptor for LRN forward propagation primitive.
dnnl_status_t DNNL_API dnnl_matmul_desc_init(dnnl_matmul_desc_t *matmul_desc, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_desc)
Initializes a matrix multiplication descriptor.
dnnl_data_type_t
Data type specification.
Definition: dnnl_types.h:62
dnnl_status_t DNNL_API dnnl_memory_desc_init_submemory(dnnl_memory_desc_t *memory_desc, const dnnl_memory_desc_t *parent_memory_desc, const dnnl_dims_t dims, const dnnl_dims_t offsets)
Initializes a memory descriptor for a region inside an area described by an existing memory descriptor.
dnnl_format_tag_t
Memory format tag specification.
Definition: dnnl_types.h:164
dnnl_status_t DNNL_API dnnl_memory_desc_permute_axes(dnnl_memory_desc_t *out_memory_desc, const dnnl_memory_desc_t *in_memory_desc, const int *permutation)
Initializes a memory descriptor by permuting axes in an existing one.
dnnl_status_t DNNL_API dnnl_memory_unmap_data(const_dnnl_memory_t memory, void *mapped_ptr)
Unmaps a memory object and writes back any changes made to the previously mapped memory buffer.
dnnl_status_t DNNL_API dnnl_memory_create(dnnl_memory_t *memory, const dnnl_memory_desc_t *memory_desc, dnnl_engine_t engine, void *handle)
Creates a memory object.
dnnl_status_t DNNL_API dnnl_memory_get_engine(const_dnnl_memory_t memory, dnnl_engine_t *engine)
Returns the engine of a memory object.
dnnl_status_t DNNL_API dnnl_memory_desc_reshape(dnnl_memory_desc_t *out_memory_desc, const dnnl_memory_desc_t *in_memory_desc, int ndims, const dnnl_dims_t dims)
Initializes a memory descriptor by reshaping an existing one.
dnnl_status_t DNNL_API dnnl_memory_get_memory_desc(const_dnnl_memory_t memory, const dnnl_memory_desc_t **memory_desc)
Returns the memory descriptor for a memory object.
dnnl_status_t DNNL_API dnnl_memory_get_data_handle(const_dnnl_memory_t memory, void **handle)
Returns memory object's data handle.
dnnl_status_t DNNL_API dnnl_memory_set_data_handle_v2(dnnl_memory_t memory, void *handle, dnnl_stream_t stream)
Sets the underlying memory buffer.
dnnl_status_t DNNL_API dnnl_memory_desc_init_by_strides(dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims, dnnl_data_type_t data_type, const dnnl_dims_t strides)
Initializes a memory descriptor using dimensions and strides.
int64_t dnnl_dim_t
A type to describe tensor dimension.
Definition: dnnl_types.h:1333
dnnl_status_t DNNL_API dnnl_memory_destroy(dnnl_memory_t memory)
Destroys a memory object.
int DNNL_API dnnl_memory_desc_equal(const dnnl_memory_desc_t *lhs, const dnnl_memory_desc_t *rhs)
Compares two memory descriptors.
#define DNNL_MAX_NDIMS
Maximum number of dimensions a tensor can have.
Definition: dnnl_types.h:1301
dnnl_status_t DNNL_API dnnl_memory_map_data(const_dnnl_memory_t memory, void **mapped_ptr)
Maps a memory object and returns a host-side pointer to a memory buffer with a copy of its contents.
size_t DNNL_API dnnl_memory_desc_get_size(const dnnl_memory_desc_t *memory_desc)
Returns the size of a memory descriptor.
#define DNNL_MEMORY_ALLOCATE
Special pointer value that indicates that the library needs to allocate an underlying buffer for a memory object.
Definition: dnnl_types.h:1510
dnnl_status_t DNNL_API dnnl_memory_desc_init_by_tag(dnnl_memory_desc_t *memory_desc, int ndims, const dnnl_dims_t dims, dnnl_data_type_t data_type, dnnl_format_tag_t tag)
Initializes a memory descriptor using dimensions and memory format tag.
@ dnnl_f16
16-bit/half-precision floating point.
Definition: dnnl_types.h:66
@ dnnl_bf16
non-standard 16-bit (bfloat16 w/ 7 bit mantissa) floating point.
Definition: dnnl_types.h:68
@ dnnl_f32
32-bit/single-precision floating point.
Definition: dnnl_types.h:70
@ dnnl_data_type_undef
Undefined data type, used for empty memory descriptors.
Definition: dnnl_types.h:64
@ dnnl_s8
8-bit signed integer.
Definition: dnnl_types.h:74
@ dnnl_s32
32-bit signed integer.
Definition: dnnl_types.h:72
@ dnnl_u8
8-bit unsigned integer.
Definition: dnnl_types.h:76
@ dnnl_abcdefhg
permuted 8D tensor
Definition: dnnl_types.h:216
@ dnnl_aBCdef2b4c2b
6D tensor blocked by 3rd dimension with block size 4
Definition: dnnl_types.h:362
@ dnnl_abcdefghi
plain 9D tensor
Definition: dnnl_types.h:186
@ dnnl_acdeb
permuted 5D tensor
Definition: dnnl_types.h:199
@ dnnl_abcdefgh
plain 8D tensor
Definition: dnnl_types.h:185
@ dnnl_abcdefghikj
permuted 11D tensor
Definition: dnnl_types.h:219
@ dnnl_ab
plain 2D tensor
Definition: dnnl_types.h:178
@ dnnl_ABcd8b8a
4D tensor blocked by 1st and 2nd dimension with block size 8
Definition: dnnl_types.h:288
@ dnnl_cdba
permuted 4D tensor
Definition: dnnl_types.h:208
@ dnnl_abcdefghijkl
plain 12D tensor
Definition: dnnl_types.h:189
@ dnnl_aBcdef4b
6D tensor blocked by 2nd dimension with block size 4
Definition: dnnl_types.h:364
@ dnnl_abcdegf
permuted 7D tensor
Definition: dnnl_types.h:215
@ dnnl_abcdfe
permuted 6D tensor
Definition: dnnl_types.h:214
@ dnnl_aBcd4b
4D tensor blocked by 2nd dimension with block size 4
Definition: dnnl_types.h:263
@ dnnl_nCdhw16c
5D CNN activations tensor blocked by channels with block size 16, an alias to dnnl_aBcde16b
Definition: dnnl_types.h:682
@ dnnl_abcde
plain 5D tensor
Definition: dnnl_types.h:182
@ dnnl_decab
permuted 5D tensor
Definition: dnnl_types.h:211
@ dnnl_bca
permuted 3D tensor
Definition: dnnl_types.h:204
@ dnnl_aBcde4b
5D tensor blocked by 2nd dimension with block size 4
Definition: dnnl_types.h:315
@ dnnl_aBc16b
3D tensor blocked by 2nd dimension with block size 16
Definition: dnnl_types.h:229
@ dnnl_aBcdef16b
6D tensor blocked by 2nd dimension with block size 16
Definition: dnnl_types.h:354
@ dnnl_aBCde2b4c2b
5D tensor blocked by 3rd dimension with block size 4
Definition: dnnl_types.h:352
@ dnnl_aBc4b
3D tensor blocked by 2nd dimension with block size 4
Definition: dnnl_types.h:235
@ dnnl_abcdefghijk
plain 11D tensor
Definition: dnnl_types.h:188
@ dnnl_bacde
permuted 5D tensor
Definition: dnnl_types.h:203
@ dnnl_aBcd16b
4D tensor blocked by 2nd dimension with block size 16
Definition: dnnl_types.h:255
@ dnnl_cba
permuted 3D tensor
Definition: dnnl_types.h:207
@ dnnl_ba
permuted 2D tensor
Definition: dnnl_types.h:200
@ dnnl_ABcde2b8a4b
5D tensor blocked by 1st dimension with block size 8
Definition: dnnl_types.h:304
@ dnnl_abcd
plain 4D tensor
Definition: dnnl_types.h:180
@ dnnl_format_tag_undef
Undefined memory format tag.
Definition: dnnl_types.h:166
@ dnnl_nCdhw4c
5D CNN activations tensor blocked by channels with block size 4, an alias to dnnl_aBcde4b
Definition: dnnl_types.h:685
@ dnnl_defcab
permuted 6D tensor
Definition: dnnl_types.h:212
@ dnnl_abcdef
plain 6D tensor
Definition: dnnl_types.h:183
@ dnnl_nChw8c
4D CNN activations tensor blocked by channels with block size 8, an alias to dnnl_aBcd8b
Definition: dnnl_types.h:700
@ dnnl_a
plain 1D tensor
Definition: dnnl_types.h:177
@ dnnl_nChw4c
4D CNN activations tensor blocked by channels with block size 4, an alias to dnnl_aBcd4b
Definition: dnnl_types.h:697
@ dnnl_acbdef
permuted 6D tensor
Definition: dnnl_types.h:197
@ dnnl_acdb
permuted 4D tensor
Definition: dnnl_types.h:198
@ dnnl_aBcd8b
4D tensor blocked by 2nd dimension with block size 8
Definition: dnnl_types.h:282
@ dnnl_aBc8b
3D tensor blocked by 2nd dimension with block size 8
Definition: dnnl_types.h:245
@ dnnl_nCw4c
3D CNN activations tensor blocked by channels with block size 4, an alias to dnnl_aBc4b
Definition: dnnl_types.h:709
@ dnnl_abcdefg
plain 7D tensor
Definition: dnnl_types.h:184
@ dnnl_aBcde8b
5D tensor blocked by 2nd dimension with block size 8
Definition: dnnl_types.h:330
@ dnnl_nChw16c
4D CNN activations tensor blocked by channels with block size 16, an alias to dnnl_aBcd16b
Definition: dnnl_types.h:694
@ dnnl_abdfce
permuted 6D tensor
Definition: dnnl_types.h:424
@ dnnl_abdec
permuted 5D tensor
Definition: dnnl_types.h:194
@ dnnl_bacd
permuted 4D tensor
Definition: dnnl_types.h:202
@ dnnl_nCdhw8c
5D CNN activations tensor blocked by channels with block size 8, an alias to dnnl_aBcde8b
Definition: dnnl_types.h:688
@ dnnl_aBcde32b
5D tensor blocked by 2nd dimension with block size 32
Definition: dnnl_types.h:313
@ dnnl_abced
permuted 5D tensor
Definition: dnnl_types.h:213
@ dnnl_bcda
permuted 4D tensor
Definition: dnnl_types.h:205
@ dnnl_acbde
permuted 5D tensor
Definition: dnnl_types.h:196
@ dnnl_aBCd2b4c2b
4D tensor blocked by 3rd dimension with block size 4
Definition: dnnl_types.h:300
@ dnnl_abcdefgih
permuted 9D tensor
Definition: dnnl_types.h:217
@ dnnl_bcdea
permuted 5D tensor
Definition: dnnl_types.h:206
@ dnnl_abdefc
permuted 6D tensor
Definition: dnnl_types.h:425
@ dnnl_aBcde16b
5D tensor blocked by 2nd dimension with block size 16
Definition: dnnl_types.h:306
@ dnnl_nCw8c
3D CNN activations tensor blocked by channels with block size 8, an alias to dnnl_aBc8b
Definition: dnnl_types.h:712
@ dnnl_abdc
permuted 4D tensor
Definition: dnnl_types.h:193
@ dnnl_ABcde4b16a4b
5D tensor blocked by 1st dimension with block size 16
Definition: dnnl_types.h:302
@ dnnl_aBcd32b
4D tensor blocked by 2nd dimension with block size 32
Definition: dnnl_types.h:261
@ dnnl_abcdefghijlk
permuted 12D tensor
Definition: dnnl_types.h:220
@ dnnl_format_tag_last
Just a sentinel, not a real memory format tag.
Definition: dnnl_types.h:543
@ dnnl_abc
plain 3D tensor
Definition: dnnl_types.h:179
@ dnnl_bac
permuted 3D tensor
Definition: dnnl_types.h:201
@ dnnl_dcab
permuted 4D tensor
Definition: dnnl_types.h:209
@ dnnl_cdeba
permuted 5D tensor
Definition: dnnl_types.h:210
@ dnnl_acb
permuted 3D tensor
Definition: dnnl_types.h:195
@ dnnl_aBc32b
3D tensor blocked by 2nd dimension with block size 32
Definition: dnnl_types.h:233
@ dnnl_abcdefghji
permuted 10D tensor
Definition: dnnl_types.h:218
@ dnnl_nCw16c
3D CNN activations tensor blocked by channels with block size 16, an alias to dnnl_aBc16b
Definition: dnnl_types.h:706
@ dnnl_aBCdef2c8b4c
6D tensor blocked by 2nd dimension with block size 8
Definition: dnnl_types.h:359
@ dnnl_abcdefghij
plain 10D tensor
Definition: dnnl_types.h:187
@ dnnl_format_tag_any
Undefined memory format tag. The primitive selects a format automatically.
Definition: dnnl_types.h:169
@ dnnl_blocked
A tensor in a generic format described by the stride and blocking values in each dimension.
Definition: dnnl_types.h:89
@ dnnl_format_kind_wino
Weights format used in 8bit Winograd convolution.
Definition: dnnl_types.h:91
@ dnnl_format_kind_any
Unspecified format kind.
Definition: dnnl_types.h:85
@ dnnl_format_kind_undef
Undefined memory format kind, used for empty memory descriptors.
Definition: dnnl_types.h:82
@ dnnl_format_kind_rnn_packed
Packed weights format used in RNN.
Definition: dnnl_types.h:93
dnnl_status_t DNNL_API dnnl_pooling_v2_backward_desc_init(dnnl_pooling_v2_desc_t *pool_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t kernel, const dnnl_dims_t dilation, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for pooling v2 (pooling with dilation support) backward propagation primitive.
dnnl_status_t DNNL_API dnnl_pooling_v2_forward_desc_init(dnnl_pooling_v2_desc_t *pool_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *dst_desc, const dnnl_dims_t strides, const dnnl_dims_t kernel, const dnnl_dims_t dilation, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for pooling v2 (pooling with dilation support) forward propagation primitive.
dnnl_status_t DNNL_API dnnl_pooling_forward_desc_init(dnnl_pooling_desc_t *pool_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *dst_desc, const dnnl_dims_t strides, const dnnl_dims_t kernel, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for pooling forward propagation primitive.
dnnl_status_t DNNL_API dnnl_pooling_backward_desc_init(dnnl_pooling_desc_t *pool_desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *diff_dst_desc, const dnnl_dims_t strides, const dnnl_dims_t kernel, const dnnl_dims_t padding_l, const dnnl_dims_t padding_r)
Initializes a descriptor for pooling backward propagation primitive.
dnnl_status_t DNNL_API dnnl_prelu_forward_desc_init(dnnl_prelu_desc_t *prelu_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *data_desc, const dnnl_memory_desc_t *weights_desc)
Initializes a descriptor for PReLU (leaky ReLU with trainable alpha parameter) forward propagation primitive.
dnnl_status_t DNNL_API dnnl_prelu_backward_desc_init(dnnl_prelu_desc_t *prelu_desc, const dnnl_memory_desc_t *data_desc, const dnnl_memory_desc_t *weights_desc, const dnnl_memory_desc_t *diff_data_desc, const dnnl_memory_desc_t *diff_weights_desc)
Initializes a descriptor for PReLU (leaky ReLU with trainable alpha parameter) backward propagation primitive.
void set_primitive_cache_capacity(int capacity)
Sets a number of primitives that can be held in the primitive cache at a time.
Definition: dnnl.hpp:10905
dnnl_status_t DNNL_API dnnl_set_primitive_cache_capacity(int capacity)
Sets a number of primitives that can be held in the primitive cache at a time.
dnnl_status_t DNNL_API dnnl_get_primitive_cache_capacity(int *capacity)
Returns the number of primitives that can be held in the primitive cache at the same time.
int get_primitive_cache_capacity()
Returns the number of primitives that can be held in the primitive cache at the same time.
Definition: dnnl.hpp:10897
dnnl_status_t DNNL_API dnnl_primitive_desc_query(const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what, int index, void *result)
Queries a primitive descriptor for various pieces of information.
#define DNNL_ARG_DST_ITER
A special mnemonic for RNN output recurrent hidden state vector.
Definition: dnnl_types.h:2318
dnnl_status_t DNNL_API dnnl_primitive_desc_iterator_destroy(dnnl_primitive_desc_iterator_t iterator)
Destroys a primitive descriptor iterator.
#define DNNL_ARG_WEIGHTS_LAYER
A special mnemonic for RNN weights applied to the layer input.
Definition: dnnl_types.h:2336
#define DNNL_ARG_DIFF_BIAS
Gradient (diff) of the bias tensor argument.
Definition: dnnl_types.h:2443
#define DNNL_ARG_DIFF_SRC_ITER_C
A special mnemonic for gradient (diff) of RNN input recurrent cell state vector.
Definition: dnnl_types.h:2389
#define DNNL_ARG_DIFF_SRC_LAYER
A special mnemonic for gradient (diff) of RNN input vector.
Definition: dnnl_types.h:2377
#define DNNL_ARG_DIFF_WEIGHTS_PEEPHOLE
A special mnemonic for diff of RNN weights applied to the peephole weights.
Definition: dnnl_types.h:2434
#define DNNL_ARG_WEIGHTS_PROJECTION
A special mnemonic for RNN weights applied to the projection weights.
Definition: dnnl_types.h:2354
dnnl_normalization_flags_t
Flags for normalization primitives.
Definition: dnnl_types.h:1241
#define DNNL_ARG_DIFF_WEIGHTS_PROJECTION
A special mnemonic for diff of RNN weights applied to the projection weights.
Definition: dnnl_types.h:2440
const dnnl_memory_desc_t DNNL_API * dnnl_primitive_desc_query_md(const_dnnl_primitive_desc_t primitive_desc, dnnl_query_t what, int index)
Queries primitive descriptor for a memory descriptor.
dnnl_status_t DNNL_API dnnl_primitive_desc_get_attr(const_dnnl_primitive_desc_t primitive_desc, const_dnnl_primitive_attr_t *attr)
Returns a constant reference to the attributes of a primitive descriptor.
#define DNNL_ARG_DIFF_WEIGHTS_ITER
A special mnemonic for diff of RNN weights applied to the recurrent input.
Definition: dnnl_types.h:2428
#define DNNL_ARG_DIFF_SRC_ITER
A special mnemonic for gradient (diff) of RNN input recurrent hidden state vector.
Definition: dnnl_types.h:2383
#define DNNL_ARG_DIFF_DST_ITER_C
A special mnemonic for gradient (diff) of RNN output recurrent cell state vector.
Definition: dnnl_types.h:2410
dnnl_status_t DNNL_API dnnl_primitive_execute(const_dnnl_primitive_t primitive, dnnl_stream_t stream, int nargs, const dnnl_exec_arg_t *args)
Executes a primitive.
#define DNNL_ARG_WEIGHTS_ITER
A special mnemonic for RNN weights applied to the recurrent input.
Definition: dnnl_types.h:2342
dnnl_status_t DNNL_API dnnl_primitive_desc_iterator_next(dnnl_primitive_desc_iterator_t iterator)
Advances the primitive descriptor iterator to point to the next available implementation.
dnnl_status_t DNNL_API dnnl_primitive_desc_destroy(dnnl_primitive_desc_t primitive_desc)
Destroys a primitive descriptor.
const void * const_dnnl_op_desc_t
A pointer to any of the operation descriptors (constant variant).
Definition: dnnl_types.h:1522
const_dnnl_primitive_desc_t get_primitive_desc() const
Returns the C API primitive descriptor of the underlying C API primitive.
Definition: dnnl.hpp:368
dnnl_status_t DNNL_API dnnl_primitive_get_primitive_desc(const_dnnl_primitive_t primitive, const_dnnl_primitive_desc_t *primitive_desc)
Retrieves a constant reference to the primitive descriptor of a given primitive.
#define DNNL_ARG_DST_ITER_C
A special mnemonic for LSTM output recurrent cell state vector.
Definition: dnnl_types.h:2324
#define DNNL_ARG_SRC_ITER_C
A special mnemonic for RNN input recurrent cell state vector.
Definition: dnnl_types.h:2301
query
Primitive descriptor query specification.
Definition: dnnl.hpp:745
#define DNNL_ARG_FROM
A special mnemonic for reorder source argument.
Definition: dnnl_types.h:2289
dnnl_alg_kind_t
Kinds of algorithms.
Definition: dnnl_types.h:1107
dnnl_primitive_kind_t
Kinds of primitives.
Definition: dnnl_types.h:1053
dnnl_query_t
Primitive descriptor query specification.
Definition: dnnl_types.h:2514
dnnl_primitive_kind_t convert_to_c(primitive::kind akind)
Converts primitive kind enum value from C++ API to C API type.
Definition: dnnl.hpp:364
struct dnnl_primitive_desc * dnnl_primitive_desc_t
A primitive descriptor handle.
Definition: dnnl_types.h:2190
#define DNNL_ARG_WEIGHTS_PEEPHOLE
A special mnemonic for RNN weights applied to the peephole weights.
Definition: dnnl_types.h:2348
kind get_kind() const
Returns the kind of the primitive.
Definition: dnnl.hpp:375
#define DNNL_ARG_SRC_LAYER
A special mnemonic for RNN input vector.
Definition: dnnl_types.h:2286
dnnl_status_t DNNL_API dnnl_primitive_destroy(dnnl_primitive_t primitive)
Destroys a primitive.
#define DNNL_ARG_DIFF_WEIGHTS_LAYER
A special mnemonic for diff of RNN weights applied to the layer input.
Definition: dnnl_types.h:2422
dnnl_status_t DNNL_API dnnl_primitive_desc_iterator_create(dnnl_primitive_desc_iterator_t *iterator, const_dnnl_op_desc_t op_desc, const_dnnl_primitive_attr_t attr, dnnl_engine_t engine, const_dnnl_primitive_desc_t hint_forward_primitive_desc)
Creates a primitive descriptor iterator.
#define DNNL_ARG_DST_LAYER
A special mnemonic for RNN output vector. An alias for DNNL_ARG_DST_0.
Definition: dnnl_types.h:2312
dnnl_status_t DNNL_API dnnl_primitive_create(dnnl_primitive_t *primitive, const_dnnl_primitive_desc_t primitive_desc)
Creates a primitive.
#define DNNL_ARG_BIAS
Bias tensor argument.
Definition: dnnl_types.h:2357
normalization_flags
Flags for normalization primitives.
Definition: dnnl.hpp:615
#define DNNL_ARG_DIFF_DST_ITER
A special mnemonic for gradient (diff) of RNN output recurrent hidden state vector.
Definition: dnnl_types.h:2404
dnnl_prop_kind_t
Kinds of propagation.
Definition: dnnl_types.h:1026
dnnl_status_t DNNL_API dnnl_primitive_desc_clone(dnnl_primitive_desc_t *primitive_desc, const_dnnl_primitive_desc_t existing_primitive_desc)
Clones a primitive descriptor.
#define DNNL_ARG_SRC_ITER
A special mnemonic for RNN input recurrent hidden state vector.
Definition: dnnl_types.h:2295
dnnl_primitive_desc_t DNNL_API dnnl_primitive_desc_iterator_fetch(const_dnnl_primitive_desc_iterator_t iterator)
Fetches the current primitive descriptor from a primitive descriptor iterator.
#define DNNL_ARG_TO
A special mnemonic for reorder destination argument.
Definition: dnnl_types.h:2310
#define DNNL_ARG_DIFF_DST_LAYER
A special mnemonic for gradient (diff) of RNN output vector.
Definition: dnnl_types.h:2398
@ dnnl_fuse_norm_relu
Fuse with ReLU.
Definition: dnnl_types.h:1289
@ dnnl_normalization_flags_none
Use no normalization flags.
Definition: dnnl_types.h:1250
@ dnnl_use_scaleshift
Use scale and shift parameters.
Definition: dnnl_types.h:1276
@ dnnl_use_global_stats
Use global statistics.
Definition: dnnl_types.h:1263
@ batch_normalization_d
batch normalization descriptor
@ weights_md
weights memory desc
@ memory_consumption_s64
memory required for scratchpad (bytes)
@ shuffle_d
shuffle descriptor
@ deconvolution_d
deconvolution descriptor
@ impl_info_str
implementation name
@ diff_weights_md
weights gradient (diff) memory desc
@ workspace_md
workspace memory desc
@ reduction_d
reduction descriptor
@ eltwise_d
eltwise descriptor
@ matmul_d
matmul descriptor
@ rnn_d
rnn descriptor
@ softmax_d
softmax descriptor
@ num_of_outputs_s32
number of outputs expected
@ primitive_kind
primitive kind
@ dst_md
destination memory desc
@ scratchpad_engine
scratchpad engine
@ reorder_src_engine
reorder source engine
@ op_d
operation descriptor
@ layer_normalization_d
layer normalization descriptor
@ logsoftmax_d
logsoftmax descriptor
@ pooling_d
pooling descriptor
@ num_of_inputs_s32
number of inputs expected
@ diff_src_md
source gradient (diff) memory desc
@ src_md
source memory desc
@ scratchpad_md
scratchpad memory desc
@ reorder_dst_engine
reorder destination engine
@ engine
execution engine
@ convolution_d
convolution descriptor
@ time_estimate_f64
runtime estimation (seconds), unimplemented
@ binary_d
binary descriptor
@ diff_dst_md
destination gradient (diff) memory desc
@ exec_arg_md
memory desc of an execute argument
@ inner_product_d
inner product descriptor
@ lrn_d
lrn descriptor
@ undef
no query
@ resampling_d
resampling descriptor
@ dnnl_pooling_avg_exclude_padding
Average pooling exclude padding.
Definition: dnnl_types.h:1183
@ dnnl_eltwise_clip
Eltwise: clip.
Definition: dnnl_types.h:1153
@ dnnl_eltwise_tanh_use_dst_for_bwd
Eltwise: hyperbolic tangent non-linearity (tanh) (dst for backward)
Definition: dnnl_types.h:1167
@ dnnl_eltwise_logsigmoid
Eltwise: logsigmoid.
Definition: dnnl_types.h:1163
@ dnnl_pooling_avg
Average pooling (alias for dnnl_pooling_avg_exclude_padding)
Definition: dnnl_types.h:1185
@ dnnl_eltwise_gelu_tanh
Eltwise: gelu (tanh approximation).
Definition: dnnl_types.h:1145
@ dnnl_resampling_linear
Linear Resampling Method.
Definition: dnnl_types.h:1219
@ dnnl_eltwise_sqrt
Eltwise: square root.
Definition: dnnl_types.h:1130
@ dnnl_binary_min
Binary min.
Definition: dnnl_types.h:1211
@ dnnl_reduction_norm_lp_sum
Reduction using lp norm.
Definition: dnnl_types.h:1233
@ dnnl_eltwise_abs
Eltwise: abs.
Definition: dnnl_types.h:1128
@ dnnl_reduction_norm_lp_power_p_max
Reduction using lp norm without final pth-root.
Definition: dnnl_types.h:1235
@ dnnl_reduction_min
Reduction using min.
Definition: dnnl_types.h:1223
@ dnnl_eltwise_sqrt_use_dst_for_bwd
Eltwise: square root (dst for backward)
Definition: dnnl_types.h:1171
@ dnnl_eltwise_exp
Eltwise: exponent.
Definition: dnnl_types.h:1140
@ dnnl_eltwise_square
Eltwise: square.
Definition: dnnl_types.h:1126
@ dnnl_eltwise_gelu
Eltwise: tanh-based gelu (alias for dnnl_eltwise_gelu_tanh)
Definition: dnnl_types.h:1147
@ dnnl_convolution_winograd
Winograd convolution.
Definition: dnnl_types.h:1112
@ dnnl_eltwise_clip_v2_use_dst_for_bwd
Eltwise: clip version 2 (dst for backward)
Definition: dnnl_types.h:1177
@ dnnl_lrn_across_channels
Local response normalization (LRN) across multiple channels.
Definition: dnnl_types.h:1187
@ dnnl_binary_sub
Binary sub.
Definition: dnnl_types.h:1215
@ dnnl_deconvolution_direct
Direct deconvolution.
Definition: dnnl_types.h:1116
@ dnnl_eltwise_relu
Eltwise: ReLU.
Definition: dnnl_types.h:1120
@ dnnl_convolution_auto
Convolution algorithm (either direct or Winograd) is chosen just in time.
Definition: dnnl_types.h:1114
@ dnnl_eltwise_swish
Eltwise: swish.
Definition: dnnl_types.h:1149
@ dnnl_vanilla_rnn
RNN cell.
Definition: dnnl_types.h:1191
@ dnnl_eltwise_gelu_erf
Eltwise: erf-based gelu.
Definition: dnnl_types.h:1159
@ dnnl_vanilla_lstm
LSTM cell.
Definition: dnnl_types.h:1193
@ dnnl_eltwise_elu
Eltwise: exponential linear unit (elu)
Definition: dnnl_types.h:1124
@ dnnl_vanilla_gru
GRU cell.
Definition: dnnl_types.h:1195
@ dnnl_lbr_gru
GRU cell with linear before reset.
Definition: dnnl_types.h:1203
@ dnnl_eltwise_tanh
Eltwise: hyperbolic tangent non-linearity (tanh)
Definition: dnnl_types.h:1122
@ dnnl_convolution_direct
Direct convolution.
Definition: dnnl_types.h:1110
@ dnnl_eltwise_soft_relu
Eltwise: soft_relu.
Definition: dnnl_types.h:1136
@ dnnl_eltwise_log
Eltwise: natural logarithm.
Definition: dnnl_types.h:1151
@ dnnl_eltwise_clip_v2
Eltwise: clip version 2.
Definition: dnnl_types.h:1155
@ dnnl_lrn_within_channel
LRN within a single channel.
Definition: dnnl_types.h:1189
@ dnnl_eltwise_elu_use_dst_for_bwd
Eltwise: exponential linear unit (elu) (dst for backward)
Definition: dnnl_types.h:1169
@ dnnl_deconvolution_winograd
Winograd deconvolution.
Definition: dnnl_types.h:1118
@ dnnl_reduction_mul
Reduction using mul.
Definition: dnnl_types.h:1227
@ dnnl_eltwise_pow
Eltwise: pow.
Definition: dnnl_types.h:1157
@ dnnl_eltwise_relu_use_dst_for_bwd
Eltwise: ReLU (dst for backward)
Definition: dnnl_types.h:1165
@ dnnl_reduction_max
Reduction using max.
Definition: dnnl_types.h:1221
@ dnnl_eltwise_logistic
Eltwise: logistic.
Definition: dnnl_types.h:1138
@ dnnl_pooling_avg_include_padding
Average pooling include padding.
Definition: dnnl_types.h:1181
@ dnnl_reduction_mean
Reduction using mean.
Definition: dnnl_types.h:1229
@ dnnl_pooling_max
Max pooling.
Definition: dnnl_types.h:1179
@ dnnl_eltwise_logistic_use_dst_for_bwd
Eltwise: logistic (dst for backward)
Definition: dnnl_types.h:1173
@ dnnl_binary_add
Binary add.
Definition: dnnl_types.h:1205
@ dnnl_binary_div
Binary div.
Definition: dnnl_types.h:1213
@ dnnl_reduction_norm_lp_max
Reduction using lp norm.
Definition: dnnl_types.h:1231
@ dnnl_reduction_norm_lp_power_p_sum
Reduction using lp norm without final pth-root.
Definition: dnnl_types.h:1237
@ dnnl_eltwise_round
Eltwise: round.
Definition: dnnl_types.h:1161
@ dnnl_binary_mul
Binary mul.
Definition: dnnl_types.h:1207
@ dnnl_reduction_sum
Reduction using sum.
Definition: dnnl_types.h:1225
@ dnnl_eltwise_exp_use_dst_for_bwd
Eltwise: exp (dst for backward)
Definition: dnnl_types.h:1175
@ dnnl_eltwise_bounded_relu
Eltwise: bounded_relu.
Definition: dnnl_types.h:1134
@ dnnl_eltwise_linear
Eltwise: linear.
Definition: dnnl_types.h:1132
@ dnnl_resampling_nearest
Nearest Neighbor Resampling Method.
Definition: dnnl_types.h:1217
@ dnnl_binary_max
Binary max.
Definition: dnnl_types.h:1209
@ dnnl_binary
A binary primitive.
Definition: dnnl_types.h:1087
@ dnnl_concat
A (out-of-place) concat primitive.
Definition: dnnl_types.h:1061
@ dnnl_reorder
A reorder primitive.
Definition: dnnl_types.h:1057
@ dnnl_convolution
A convolution primitive.
Definition: dnnl_types.h:1065
@ dnnl_inner_product
An inner product primitive.
Definition: dnnl_types.h:1081
@ dnnl_resampling
A resampling primitive.
Definition: dnnl_types.h:1093
@ dnnl_batch_normalization
A batch normalization primitive.
Definition: dnnl_types.h:1077
@ dnnl_undefined_primitive
Undefined primitive.
Definition: dnnl_types.h:1055
@ dnnl_sum
A sum primitive.
Definition: dnnl_types.h:1063
@ dnnl_pooling_v2
A pooling version 2 primitive (pooling with dilation support).
Definition: dnnl_types.h:1095
@ dnnl_layer_normalization
A layer normalization primitive.
Definition: dnnl_types.h:1079
@ dnnl_prelu
A PReLU primitive.
Definition: dnnl_types.h:1099
@ dnnl_eltwise
An element-wise primitive.
Definition: dnnl_types.h:1069
@ dnnl_matmul
A matrix multiplication primitive.
Definition: dnnl_types.h:1091
@ dnnl_shuffle
A shuffle primitive.
Definition: dnnl_types.h:1059
@ dnnl_logsoftmax
A logsoftmax primitive.
Definition: dnnl_types.h:1089
@ dnnl_pooling
A pooling primitive.
Definition: dnnl_types.h:1073
@ dnnl_deconvolution
A deconvolution primitive.
Definition: dnnl_types.h:1067
@ dnnl_softmax
A softmax primitive.
Definition: dnnl_types.h:1071
@ dnnl_rnn
An RNN primitive.
Definition: dnnl_types.h:1083
@ dnnl_reduction
A reduction primitive.
Definition: dnnl_types.h:1097
@ dnnl_lrn
An LRN primitive.
Definition: dnnl_types.h:1075
@ dnnl_query_resampling_d
resampling descriptor
Definition: dnnl_types.h:2557
@ dnnl_query_num_of_outputs_s32
number of outputs expected
Definition: dnnl_types.h:2521
@ dnnl_query_convolution_d
convolution descriptor
Definition: dnnl_types.h:2542
@ dnnl_query_weights_md
weights memory desc
Definition: dnnl_types.h:2566
@ dnnl_query_src_md
source memory desc
Definition: dnnl_types.h:2564
@ dnnl_query_softmax_d
softmax descriptor
Definition: dnnl_types.h:2546
@ dnnl_query_binary_d
binary descriptor
Definition: dnnl_types.h:2554
@ dnnl_query_workspace_md
workspace memory desc
Definition: dnnl_types.h:2570
@ dnnl_query_matmul_d
matrix multiplication (matmul) descriptor
Definition: dnnl_types.h:2556
@ dnnl_query_num_of_inputs_s32
number of inputs expected
Definition: dnnl_types.h:2520
@ dnnl_query_op_d
op descriptor
Definition: dnnl_types.h:2541
@ dnnl_query_diff_src_md
source gradient memory desc
Definition: dnnl_types.h:2565
@ dnnl_query_scratchpad_md
scratchpad memory desc
Definition: dnnl_types.h:2571
@ dnnl_query_shuffle_d
shuffle descriptor
Definition: dnnl_types.h:2544
@ dnnl_query_memory_consumption_s64
memory consumption – extra (scratch) memory, additional to all inputs and outputs memory (bytes)
Definition: dnnl_types.h:2524
@ dnnl_query_inner_product_d
inner product descriptor
Definition: dnnl_types.h:2551
@ dnnl_query_deconvolution_d
deconvolution descriptor
Definition: dnnl_types.h:2543
@ dnnl_query_primitive_kind
primitive kind
Definition: dnnl_types.h:2518
@ dnnl_query_batch_normalization_d
batch normalization descriptor
Definition: dnnl_types.h:2549
@ dnnl_query_impl_info_str
implementation name
Definition: dnnl_types.h:2532
@ dnnl_query_time_estimate_f64
runtime estimation (seconds)
Definition: dnnl_types.h:2523
@ dnnl_query_eltwise_d
eltwise descriptor
Definition: dnnl_types.h:2545
@ dnnl_query_diff_weights_md
weights grad. memory desc
Definition: dnnl_types.h:2567
@ dnnl_query_reduction_d
reduction descriptor
Definition: dnnl_types.h:2559
@ dnnl_query_reorder_dst_engine
destination engine
Definition: dnnl_types.h:2535
@ dnnl_query_reorder_src_engine
source engine
Definition: dnnl_types.h:2534
@ dnnl_query_scratchpad_engine
scratchpad engine – engine to be used for creating scratchpad memory
Definition: dnnl_types.h:2529
@ dnnl_query_undef
no query
Definition: dnnl_types.h:2515
@ dnnl_query_prop_kind
propagation kind
Definition: dnnl_types.h:2537
@ dnnl_query_pooling_d
pooling descriptor
Definition: dnnl_types.h:2547
@ dnnl_query_exec_arg_md
memory desc of an execute argument
Definition: dnnl_types.h:2572
@ dnnl_query_engine
execution engine
Definition: dnnl_types.h:2517
@ dnnl_query_rnn_d
rnn descriptor
Definition: dnnl_types.h:2552
@ dnnl_query_layer_normalization_d
layer normalization descriptor
Definition: dnnl_types.h:2550
@ dnnl_query_lrn_d
lrn descriptor
Definition: dnnl_types.h:2548
@ dnnl_query_dst_md
destination memory desc
Definition: dnnl_types.h:2568
@ dnnl_query_diff_dst_md
destination grad. memory desc
Definition: dnnl_types.h:2569
@ dnnl_query_logsoftmax_d
logsoftmax descriptor
Definition: dnnl_types.h:2555
@ use_scale_shift
Use scale and shift parameters.
@ none
Use no normalization flags.
@ fuse_norm_relu
Fuse normalization with ReLU.
@ use_global_stats
Use global statistics.
@ dnnl_backward_weights
Backward weights propagation.
Definition: dnnl_types.h:1046
@ dnnl_forward_inference
Forward data propagation (inference mode).
Definition: dnnl_types.h:1036
@ dnnl_backward
Backward propagation (with respect to all parameters).
Definition: dnnl_types.h:1042
@ dnnl_backward_data
Backward data propagation.
Definition: dnnl_types.h:1044
@ dnnl_prop_kind_undef
Undefined propagation type.
Definition: dnnl_types.h:1029
@ dnnl_forward
Forward data propagation (alias for dnnl_forward_training).
Definition: dnnl_types.h:1040
@ dnnl_forward_training
Forward data propagation (training mode).
Definition: dnnl_types.h:1032
@ dnnl_backward_bias
Backward bias propagation.
Definition: dnnl_types.h:1048
@ dnnl_forward_scoring
Forward data propagation (alias for dnnl_forward_inference).
Definition: dnnl_types.h:1038
dnnl_status_t DNNL_API dnnl_reduction_desc_init(dnnl_reduction_desc_t *desc, dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *dst_desc, float p, float eps)
Initializes a descriptor for a reduction primitive.
dnnl_status_t DNNL_API dnnl_reorder_primitive_desc_create(dnnl_primitive_desc_t *reorder_primitive_desc, const dnnl_memory_desc_t *src_desc, dnnl_engine_t src_engine, const dnnl_memory_desc_t *dst_desc, dnnl_engine_t dst_engine, const_dnnl_primitive_attr_t attr)
Creates a primitive descriptor for a reorder primitive.
dnnl_status_t DNNL_API dnnl_resampling_backward_desc_init(dnnl_resampling_desc_t *resampling_desc, dnnl_alg_kind_t alg_kind, const float *factors, const dnnl_memory_desc_t *diff_src_desc, const dnnl_memory_desc_t *diff_dst_desc)
Initializes a descriptor for resampling backward propagation primitive.
dnnl_status_t DNNL_API dnnl_resampling_forward_desc_init(dnnl_resampling_desc_t *resampling_desc, dnnl_prop_kind_t prop_kind, dnnl_alg_kind_t alg_kind, const float *factors, const dnnl_memory_desc_t *src_desc, const dnnl_memory_desc_t *dst_desc)
Initializes a descriptor for a resampling forward propagation primitive.
dnnl_status_t DNNL_API dnnl_lbr_gru_backward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *diff_src_layer_desc, const dnnl_memory_desc_t *diff_src_iter_desc, const dnnl_memory_desc_t *diff_weights_layer_desc, const dnnl_memory_desc_t *diff_weights_iter_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_layer_desc, const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags)
Initializes a descriptor for LBR GRU backward propagation primitive.
dnnl_status_t DNNL_API dnnl_gru_forward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, unsigned flags)
Initializes a descriptor for GRU forward propagation primitive.
rnn_direction
A direction of RNN primitive execution.
Definition: dnnl.hpp:712
dnnl_rnn_flags_t
Flags for RNN cell.
Definition: dnnl_types.h:1931
dnnl_status_t DNNL_API dnnl_vanilla_rnn_forward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, const dnnl_alg_kind_t activation, const dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, unsigned flags, float alpha, float beta)
Initializes a descriptor for vanilla RNN forward propagation primitive.
dnnl_status_t DNNL_API dnnl_lstm_backward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *src_iter_c_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *dst_iter_c_desc, const dnnl_memory_desc_t *diff_src_layer_desc, const dnnl_memory_desc_t *diff_src_iter_desc, const dnnl_memory_desc_t *diff_src_iter_c_desc, const dnnl_memory_desc_t *diff_weights_layer_desc, const dnnl_memory_desc_t *diff_weights_iter_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_layer_desc, const dnnl_memory_desc_t *diff_dst_iter_desc, const dnnl_memory_desc_t *diff_dst_iter_c_desc, unsigned flags)
Initializes a descriptor for an LSTM backward propagation primitive.
dnnl_rnn_direction_t
A direction of RNN primitive execution.
Definition: dnnl_types.h:1937
dnnl_status_t DNNL_API dnnl_vanilla_rnn_backward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, const dnnl_alg_kind_t activation, const dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *diff_src_layer_desc, const dnnl_memory_desc_t *diff_src_iter_desc, const dnnl_memory_desc_t *diff_weights_layer_desc, const dnnl_memory_desc_t *diff_weights_iter_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_layer_desc, const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags, float alpha, float beta)
Initializes a descriptor for vanilla RNN backward propagation primitive.
dnnl_status_t DNNL_API dnnl_lstm_backward_desc_init_v3(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *src_iter_c_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *weights_peephole_desc, const dnnl_memory_desc_t *weights_projection_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *dst_iter_c_desc, const dnnl_memory_desc_t *diff_src_layer_desc, const dnnl_memory_desc_t *diff_src_iter_desc, const dnnl_memory_desc_t *diff_src_iter_c_desc, const dnnl_memory_desc_t *diff_weights_layer_desc, const dnnl_memory_desc_t *diff_weights_iter_desc, const dnnl_memory_desc_t *diff_weights_peephole_desc, const dnnl_memory_desc_t *diff_weights_projection_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_layer_desc, const dnnl_memory_desc_t *diff_dst_iter_desc, const dnnl_memory_desc_t *diff_dst_iter_c_desc, unsigned flags)
Initializes a descriptor for an LSTM (with or without peephole and with or without recurrent project...
dnnl_status_t DNNL_API dnnl_lstm_backward_desc_init_v2(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *src_iter_c_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *weights_peephole_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *dst_iter_c_desc, const dnnl_memory_desc_t *diff_src_layer_desc, const dnnl_memory_desc_t *diff_src_iter_desc, const dnnl_memory_desc_t *diff_src_iter_c_desc, const dnnl_memory_desc_t *diff_weights_layer_desc, const dnnl_memory_desc_t *diff_weights_iter_desc, const dnnl_memory_desc_t *diff_weights_peephole_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_layer_desc, const dnnl_memory_desc_t *diff_dst_iter_desc, const dnnl_memory_desc_t *diff_dst_iter_c_desc, unsigned flags)
Initializes a descriptor for an LSTM (with or without peephole) backward propagation primitive.
dnnl_status_t DNNL_API dnnl_lstm_forward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *src_iter_c_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *dst_iter_c_desc, unsigned flags)
Initializes a descriptor for LSTM forward propagation primitive.
dnnl_status_t DNNL_API dnnl_lbr_gru_forward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, unsigned flags)
Initializes a descriptor for LBR GRU forward propagation primitive.
dnnl_status_t DNNL_API dnnl_lstm_forward_desc_init_v3(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *src_iter_c_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *weights_peephole_desc, const dnnl_memory_desc_t *weights_projection_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *dst_iter_c_desc, unsigned flags)
Initializes a descriptor for an LSTM (with or without peephole and with or without recurrent projecti...
rnn_flags
RNN cell flags.
Definition: dnnl.hpp:658
dnnl_status_t DNNL_API dnnl_lstm_forward_desc_init_v2(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *src_iter_c_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *weights_peephole_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *dst_iter_c_desc, unsigned flags)
Initializes a descriptor for an LSTM (with or without peephole) forward propagation primitive.
dnnl_status_t DNNL_API dnnl_gru_backward_desc_init(dnnl_rnn_desc_t *rnn_desc, dnnl_prop_kind_t prop_kind, dnnl_rnn_direction_t direction, const dnnl_memory_desc_t *src_layer_desc, const dnnl_memory_desc_t *src_iter_desc, const dnnl_memory_desc_t *weights_layer_desc, const dnnl_memory_desc_t *weights_iter_desc, const dnnl_memory_desc_t *bias_desc, const dnnl_memory_desc_t *dst_layer_desc, const dnnl_memory_desc_t *dst_iter_desc, const dnnl_memory_desc_t *diff_src_layer_desc, const dnnl_memory_desc_t *diff_src_iter_desc, const dnnl_memory_desc_t *diff_weights_layer_desc, const dnnl_memory_desc_t *diff_weights_iter_desc, const dnnl_memory_desc_t *diff_bias_desc, const dnnl_memory_desc_t *diff_dst_layer_desc, const dnnl_memory_desc_t *diff_dst_iter_desc, unsigned flags)
Initializes a descriptor for GRU backward propagation primitive.
@ unidirectional_left2right
Unidirectional execution of RNN primitive from left to right.
@ unidirectional_right2left
Unidirectional execution of RNN primitive from right to left.
@ bidirectional_concat
Bidirectional execution of RNN primitive with concatenation of the results.
@ unidirectional
Alias for dnnl::rnn_direction::unidirectional_left2right.
@ bidirectional_sum
Bidirectional execution of RNN primitive with summation of the results.
@ dnnl_rnn_flags_undef
Undefined RNN flags.
Definition: dnnl_types.h:1933
@ dnnl_unidirectional
Alias for dnnl_unidirectional_left2right.
Definition: dnnl_types.h:1949
@ dnnl_bidirectional_concat
Bidirectional execution of RNN primitive with concatenation of the results.
Definition: dnnl_types.h:1944
@ dnnl_bidirectional_sum
Bidirectional execution of RNN primitive with summation of the results.
Definition: dnnl_types.h:1947
@ dnnl_unidirectional_left2right
Unidirectional execution of RNN primitive from left to right.
Definition: dnnl_types.h:1939
@ dnnl_unidirectional_right2left
Unidirectional execution of RNN primitive from right to left.
Definition: dnnl_types.h:1941
@ undef
Undefined RNN flags.
dnnl_status_t DNNL_API dnnl_set_jit_dump(int enable)
Configures dumping of JIT-generated code.
status set_max_cpu_isa(cpu_isa isa)
Sets the maximal ISA the library can dispatch to on the CPU.
Definition: dnnl.hpp:10858
dnnl_status_t DNNL_API dnnl_set_verbose(int level)
Configures verbose output to stdout.
status set_jit_dump(int enable)
Configures dumping of JIT-generated code.
Definition: dnnl.hpp:10817
status set_cpu_isa_hints(cpu_isa_hints isa_hints)
Sets the hints flag for the CPU ISA.
Definition: dnnl.hpp:10877
dnnl_cpu_isa_t
CPU instruction set flags.
Definition: dnnl_types.h:2664
status set_verbose(int level)
Configures verbose output to stdout.
Definition: dnnl.hpp:10807
cpu_isa get_effective_cpu_isa()
Gets the maximal ISA the library can dispatch to on the CPU.
Definition: dnnl.hpp:10864
dnnl_status_t DNNL_API dnnl_set_max_cpu_isa(dnnl_cpu_isa_t isa)
Sets the maximal ISA the library can dispatch to on the CPU.
dnnl_status_t DNNL_API dnnl_set_jit_profiling_flags(unsigned flags)
Sets library profiling flags.
status set_jit_profiling_jitdumpdir(const std::string &dir)
Sets JIT dump output path.
Definition: dnnl.hpp:10827
const dnnl_version_t DNNL_API * dnnl_version(void)
Returns library version information.
status
Status values returned by the library functions.
Definition: dnnl.hpp:10789
cpu_isa_hints get_cpu_isa_hints()
Gets the ISA specific hints that library can follow.
Definition: dnnl.hpp:10883
status set_jit_profiling_flags(unsigned flags)
Sets library profiling flags.
Definition: dnnl.hpp:10822
const version_t * version()
Returns library version information.
Definition: dnnl.hpp:10812
cpu_isa
CPU instruction set flags.
Definition: dnnl.hpp:10832
dnnl_cpu_isa_t DNNL_API dnnl_get_effective_cpu_isa(void)
Gets the maximal ISA the library can dispatch to on the CPU.
dnnl_status_t DNNL_API dnnl_set_cpu_isa_hints(dnnl_cpu_isa_hints_t isa_hints)
Sets the hints flag for the CPU ISA.
dnnl_cpu_isa_hints_t DNNL_API dnnl_get_cpu_isa_hints(void)
Gets the ISA specific hints that library can follow.
dnnl_cpu_isa_hints_t
CPU ISA hints flags.
Definition: dnnl_types.h:2710
cpu_isa_hints
CPU ISA hints flags.
Definition: dnnl.hpp:10869
dnnl_status_t DNNL_API dnnl_set_jit_profiling_jitdumpdir(const char *dir)
Sets JIT dump output path.
@ dnnl_cpu_isa_avx512_mic
Intel Advanced Vector Extensions 512 (Intel AVX-512) subset for Intel Xeon Phi processors x200 Series...
Definition: dnnl_types.h:2679
@ dnnl_cpu_isa_avx
Intel Advanced Vector Extensions (Intel AVX)
Definition: dnnl_types.h:2672
@ dnnl_cpu_isa_avx512_core_amx
Intel AVX-512, Intel DL Boost and bfloat16 support and Intel AMX with 8-bit integer and bfloat16 supp...
Definition: dnnl_types.h:2702
@ dnnl_cpu_isa_avx512_core_vnni
Intel AVX-512 and Intel Deep Learning Boost (Intel DL Boost) support for Intel Xeon Scalable processo...
Definition: dnnl_types.h:2692
@ dnnl_cpu_isa_avx2
Intel Advanced Vector Extensions 2 (Intel AVX2)
Definition: dnnl_types.h:2675
@ dnnl_cpu_isa_all
Any ISA (except those listed as initial support)
Definition: dnnl_types.h:2666
@ dnnl_cpu_isa_avx512_core
Intel AVX-512 subset for Intel Xeon Scalable processor family and Intel Core processor family.
Definition: dnnl_types.h:2687
@ dnnl_cpu_isa_sse41
Intel Streaming SIMD Extensions 4.1 (Intel SSE4.1)
Definition: dnnl_types.h:2669
@ dnnl_cpu_isa_avx2_vnni
Intel AVX2 and Intel Deep Learning Boost (Intel DL Boost) support.
Definition: dnnl_types.h:2705
@ dnnl_cpu_isa_avx512_core_bf16
Intel AVX-512, Intel DL Boost and bfloat16 support for Intel Xeon Scalable processor family and Intel...
Definition: dnnl_types.h:2697
@ dnnl_cpu_isa_avx512_mic_4ops
Intel AVX-512 subset for Intel Xeon Phi processors 7235, 7285, 7295 Series.
Definition: dnnl_types.h:2683
@ not_required
Queried element is not required for given primitive.
@ invalid_arguments
The operation failed because of incorrect function arguments.
@ success
The operation was successful.
@ unimplemented
The operation failed because requested functionality is not implemented.
@ runtime_error
Primitive or engine failed on execution.
@ out_of_memory
The operation failed due to an out-of-memory condition.
@ iterator_ends
Primitive iterator passed over last primitive descriptor.
@ avx512_mic
Intel Advanced Vector Extensions 512 (Intel AVX-512) subset for Intel Xeon Phi processors x200 Series...
@ avx2
Intel Advanced Vector Extensions 2 (Intel AVX2)
@ avx2_vnni
Intel AVX2 and Intel Deep Learning Boost (Intel DL Boost) support.
@ avx
Intel Advanced Vector Extensions (Intel AVX)
@ all
Any ISA (except those listed as initial support)
@ avx512_core
Intel AVX-512 subset for Intel Xeon Scalable processor family and Intel Core processor family.
@ avx512_mic_4ops
Intel AVX-512 subset for Intel Xeon Phi processors 7235, 7285, 7295 Series.
@ sse41
Intel Streaming SIMD Extensions 4.1 (Intel SSE4.1)
@ avx512_core_vnni
Intel AVX-512 and Intel Deep Learning Boost (Intel DL Boost) support for Intel Xeon Scalable processo...
@ avx512_core_amx
Intel AVX-512, Intel DL Boost and bfloat16 support and Intel AMX with 8-bit integer and bfloat16 supp...
@ avx512_core_bf16
Intel AVX-512, Intel DL Boost and bfloat16 support for Intel Xeon Scalable processor family and Intel...
@ dnnl_cpu_isa_no_hints
No hints (use default features)
Definition: dnnl_types.h:2712
@ dnnl_cpu_isa_prefer_ymm
Prefer to exclusively use Ymm registers for computations.
Definition: dnnl_types.h:2715
@ no_hints
No hints (use default features)
@ prefer_ymm
Prefer to exclusively use Ymm registers for computations.
dnnl_status_t DNNL_API dnnl_shuffle_forward_desc_init(dnnl_shuffle_desc_t *shuffle_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *data_desc, int axis, dnnl_dim_t group_size)
Initializes a descriptor for shuffle forward propagation primitive.
dnnl_status_t DNNL_API dnnl_shuffle_backward_desc_init(dnnl_shuffle_desc_t *shuffle_desc, const dnnl_memory_desc_t *diff_data_desc, int axis, dnnl_dim_t group_size)
Initializes a descriptor for shuffle backward propagation primitive.
dnnl_status_t DNNL_API dnnl_softmax_backward_desc_init(dnnl_softmax_desc_t *softmax_desc, const dnnl_memory_desc_t *diff_data_desc, const dnnl_memory_desc_t *data_desc, int softmax_axis)
Initializes a descriptor for softmax backward propagation primitive.
dnnl_status_t DNNL_API dnnl_softmax_forward_desc_init(dnnl_softmax_desc_t *softmax_desc, dnnl_prop_kind_t prop_kind, const dnnl_memory_desc_t *data_desc, int softmax_axis)
Initializes a descriptor for softmax forward propagation primitive.
dnnl_stream_flags_t
Stream flags.
Definition: dnnl_types.h:2586
dnnl_status_t DNNL_API dnnl_stream_wait(dnnl_stream_t stream)
Waits for all primitives in the execution stream to finish computations.
dnnl_status_t DNNL_API dnnl_stream_get_engine(const_dnnl_stream_t stream, dnnl_engine_t *engine)
Returns the engine of a stream object.
dnnl_status_t DNNL_API dnnl_stream_destroy(dnnl_stream_t stream)
Destroys an execution stream.
dnnl_status_t DNNL_API dnnl_stream_create(dnnl_stream_t *stream, dnnl_engine_t engine, unsigned flags)
Creates an execution stream.
@ dnnl_stream_out_of_order
Out-of-order execution.
Definition: dnnl_types.h:2590
@ dnnl_stream_default_flags
Default stream configuration.
Definition: dnnl_types.h:2592
dnnl_status_t DNNL_API dnnl_sum_primitive_desc_create(dnnl_primitive_desc_t *sum_primitive_desc, const dnnl_memory_desc_t *dst_desc, int n, const float *scales, const dnnl_memory_desc_t *src_descs, const_dnnl_primitive_attr_t attr, dnnl_engine_t engine)
Creates a primitive descriptor for an (out-of-place) sum primitive.
dnnl_status_t
Status values returned by the library functions.
Definition: dnnl_types.h:39
@ dnnl_iterator_ends
Primitive iterator passed over last primitive descriptor.
Definition: dnnl_types.h:49
@ dnnl_runtime_error
Primitive or engine failed on execution.
Definition: dnnl_types.h:51
@ dnnl_unimplemented
The operation failed because requested functionality is not implemented.
Definition: dnnl_types.h:47
@ dnnl_out_of_memory
The operation failed due to an out-of-memory condition.
Definition: dnnl_types.h:43
@ dnnl_success
The operation was successful.
Definition: dnnl_types.h:41
@ dnnl_invalid_arguments
The operation failed because of incorrect function arguments.
Definition: dnnl_types.h:45
@ dnnl_not_required
Queried element is not required for given primitive.
Definition: dnnl_types.h:53
oneDNN namespace
Definition: dnnl.hpp:74
oneAPI namespace
Definition: dnnl.hpp:10981
C API.
Descriptor for a batch normalization backward propagation primitive.
Definition: dnnl.hpp:6605
desc(prop_kind aprop_kind, const memory::desc &diff_data_desc, const memory::desc &data_desc, float epsilon, normalization_flags flags)
Constructs a batch normalization descriptor for backward propagation.
Definition: dnnl.hpp:6620
Primitive descriptor for a batch normalization backward propagation primitive.
Definition: dnnl.hpp:6634
primitive_desc(const desc &adesc, const engine &aengine, const batch_normalization_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a batch normalization backward propagation primitive.
Definition: dnnl.hpp:6651
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition: dnnl.hpp:6694
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a batch normalization backward propagation primitive from a C API primitive descriptor that must have a matching kind.
Definition: dnnl.hpp:6684
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:6719
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:6700
memory::desc variance_desc() const
Returns memory descriptor for variance.
Definition: dnnl.hpp:6714
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const batch_normalization_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a batch normalization backward propagation primitive.
Definition: dnnl.hpp:6671
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:6697
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:6691
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:6703
memory::desc diff_weights_desc() const
Returns a diff weights memory descriptor.
Definition: dnnl.hpp:6706
memory::desc mean_desc() const
Returns memory descriptor for mean.
Definition: dnnl.hpp:6711
Batch normalization backward propagation primitive.
Definition: dnnl.hpp:6603
batch_normalization_backward()=default
Default constructor. Produces an empty object.
batch_normalization_backward(const primitive_desc &pd)
Constructs a batch normalization backward propagation primitive.
Definition: dnnl.hpp:6728
Descriptor for a batch normalization forward propagation primitive.
Definition: dnnl.hpp:6476
desc(prop_kind aprop_kind, const memory::desc &data_desc, float epsilon, normalization_flags flags)
Constructs a batch normalization descriptor for forward propagation.
Definition: dnnl.hpp:6493
Primitive descriptor for a batch normalization forward propagation primitive.
Definition: dnnl.hpp:6506
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:6554
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition: dnnl.hpp:6560
memory::desc mean_desc() const
Returns memory descriptor for mean.
Definition: dnnl.hpp:6567
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a batch normalization forward propagation primitive.
Definition: dnnl.hpp:6520
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a batch normalization forward propagation primitive.
Definition: dnnl.hpp:6536
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:6563
memory::desc variance_desc() const
Returns memory descriptor for variance.
Definition: dnnl.hpp:6571
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a batch normalization forward propagation primitive from a C API primitive descriptor that must have a matching kind.
Definition: dnnl.hpp:6547
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:6557
Batch normalization forward propagation primitive.
Definition: dnnl.hpp:6474
batch_normalization_forward()=default
Default constructor. Produces an empty object.
batch_normalization_forward(const primitive_desc &pd)
Constructs a batch normalization forward propagation primitive.
Definition: dnnl.hpp:6599
Descriptor for an elementwise binary operator primitive.
Definition: dnnl.hpp:9754
desc()=default
Default constructor. Produces an empty object.
dnnl_binary_desc_t data
Underlying C operation descriptor.
Definition: dnnl.hpp:9756
desc(algorithm aalgorithm, const memory::desc &src0, const memory::desc &src1, const memory::desc &dst)
Constructs a descriptor for an elementwise binary operator primitive.
Definition: dnnl.hpp:9768
Primitive descriptor for an elementwise binary operator primitive.
Definition: dnnl.hpp:9779
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for an elementwise binary operator primitive.
Definition: dnnl.hpp:9807
memory::desc src_desc(int idx=0) const
Returns a source memory descriptor.
Definition: dnnl.hpp:9820
memory::desc src0_desc() const
Returns the memory descriptor for source #0.
Definition: dnnl.hpp:9823
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a binary primitive from a C API primitive descriptor that must have a matching kind.
Definition: dnnl.hpp:9816
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:9829
memory::desc src1_desc() const
Returns the memory descriptor for source #1.
Definition: dnnl.hpp:9826
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for an elementwise binary operator primitive.
Definition: dnnl.hpp:9792
Elementwise binary operator primitive.
Definition: dnnl.hpp:9752
binary()=default
Default constructor. Produces an empty object.
binary(const primitive_desc &pd)
Constructs an elementwise binary operation primitive.
Definition: dnnl.hpp:9838
Primitive descriptor for a concat primitive.
Definition: dnnl.hpp:3695
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:3764
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for concat primitive from a C API primitive descriptor which must have a matching kind.
Definition: dnnl.hpp:3757
primitive_desc(const memory::desc &dst, int concat_dimension, const std::vector< memory::desc > &srcs, const engine &aengine, const primitive_attr &attr=primitive_attr())
Constructs a primitive descriptor for an out-of-place concatenation primitive.
Definition: dnnl.hpp:3711
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(int concat_dimension, const std::vector< memory::desc > &srcs, const engine &aengine, const primitive_attr &attr=primitive_attr())
Constructs a primitive descriptor for an out-of-place concatenation primitive.
Definition: dnnl.hpp:3738
memory::desc src_desc(int idx=0) const
Returns a source memory descriptor.
Definition: dnnl.hpp:3761
Tensor concatenation (concat) primitive.
Definition: dnnl.hpp:3693
concat()=default
Default constructor. Produces an empty object.
concat(const primitive_desc &pd)
Constructs a concatenation primitive.
Definition: dnnl.hpp:3772
Descriptor for a convolution backward propagation primitive.
Definition: dnnl.hpp:4236
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for dilated convolution backward propagation primitive.
Definition: dnnl.hpp:4307
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a convolution backward propagation primitive.
Definition: dnnl.hpp:4264
Primitive descriptor for a convolution backward propagation primitive.
Definition: dnnl.hpp:4328
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition: dnnl.hpp:4386
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:4389
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a convolution backward propagation primitive from a C API primitive descriptor that must have a matching kind.
Definition: dnnl.hpp:4378
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &adesc, const engine &aengine, const convolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a convolution backward propagation primitive.
Definition: dnnl.hpp:4345
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:4383
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const convolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a convolution backward propagation primitive.
Definition: dnnl.hpp:4365
Convolution backward propagation primitive.
Definition: dnnl.hpp:4233
convolution_backward_data()=default
Default constructor. Produces an empty object.
convolution_backward_data(const primitive_desc &pd)
Constructs a convolution backward propagation primitive.
Definition: dnnl.hpp:4398
Descriptor for a convolution weights gradient primitive.
Definition: dnnl.hpp:4404
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a convolution weights gradient primitive without bias.
Definition: dnnl.hpp:4477
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a dilated convolution weights gradient primitive with bias.
Definition: dnnl.hpp:4522
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a dilated convolution weights gradient primitive without bias.
Definition: dnnl.hpp:4569
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a convolution weights gradient primitive with bias.
Definition: dnnl.hpp:4434
Primitive descriptor for a convolution weights gradient primitive.
Definition: dnnl.hpp:4590
memory::desc diff_bias_desc() const
Returns the diff bias memory descriptor.
Definition: dnnl.hpp:4657
memory::desc diff_weights_desc() const
Returns a diff weights memory descriptor.
Definition: dnnl.hpp:4646
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const convolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a convolution weights gradient primitive.
Definition: dnnl.hpp:4625
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a convolution weights gradient primitive from a C API primitive descriptor that must have a matching kind.
Definition: dnnl.hpp:4638
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:4643
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &adesc, const engine &aengine, const convolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a convolution weights gradient primitive.
Definition: dnnl.hpp:4606
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:4651
Convolution weights gradient primitive.
Definition: dnnl.hpp:4402
convolution_backward_weights()=default
Default constructor. Produces an empty object.
convolution_backward_weights(const primitive_desc &pd)
Constructs a convolution weights gradient primitive.
Definition: dnnl.hpp:4668
Descriptor for a convolution forward propagation primitive.
Definition: dnnl.hpp:3963
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a dilated convolution forward propagation primitive with bias.
Definition: dnnl.hpp:4091
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a convolution forward propagation primitive without bias.
Definition: dnnl.hpp:4042
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a dilated convolution forward propagation primitive without bias.
Definition: dnnl.hpp:4140
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a convolution forward propagation primitive with bias.
Definition: dnnl.hpp:3996
Primitive descriptor for a convolution forward propagation primitive.
Definition: dnnl.hpp:4161
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a convolution forward propagation primitive.
Definition: dnnl.hpp:4175
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a convolution forward propagation primitive.
Definition: dnnl.hpp:4191
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:4208
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a convolution forward propagation primitive from a C API primitive descriptor that must have a matching kind.
Definition: dnnl.hpp:4202
memory::desc bias_desc() const
Returns the bias memory descriptor.
Definition: dnnl.hpp:4220
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition: dnnl.hpp:4211
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:4214
Convolution forward propagation primitive.
Definition: dnnl.hpp:3961
convolution_forward(const primitive_desc &pd)
Constructs a convolution forward propagation primitive.
Definition: dnnl.hpp:4229
convolution_forward()=default
Default constructor. Produces an empty object.
Descriptor for a deconvolution backward propagation primitive.
Definition: dnnl.hpp:4949
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a dilated deconvolution backward propagation primitive.
Definition: dnnl.hpp:5018
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a deconvolution backward propagation primitive.
Definition: dnnl.hpp:4976
Primitive descriptor for a deconvolution backward propagation primitive.
Definition: dnnl.hpp:5039
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const deconvolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a deconvolution backward propagation primitive.
Definition: dnnl.hpp:5076
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition: dnnl.hpp:5097
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:5100
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:5094
primitive_desc(const desc &adesc, const engine &aengine, const deconvolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a deconvolution backward propagation primitive.
Definition: dnnl.hpp:5056
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a deconvolution backward propagation primitive from a C API primitive descriptor that must have a matching kind.
Definition: dnnl.hpp:5089
Deconvolution backward propagation primitive.
Definition: dnnl.hpp:4947
deconvolution_backward_data()=default
Default constructor. Produces an empty object.
deconvolution_backward_data(const primitive_desc &pd)
Constructs a deconvolution backward propagation primitive.
Definition: dnnl.hpp:5109
Descriptor for a deconvolution weights gradient primitive.
Definition: dnnl.hpp:5115
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a dilated deconvolution weights gradient primitive without bias.
Definition: dnnl.hpp:5276
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a dilated deconvolution weights gradient primitive with bias.
Definition: dnnl.hpp:5230
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a deconvolution weights gradient primitive without bias.
Definition: dnnl.hpp:5186
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a deconvolution weights gradient primitive with bias.
Definition: dnnl.hpp:5144
Primitive descriptor for a deconvolution weights gradient primitive.
Definition: dnnl.hpp:5297
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:5352
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:5360
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a deconvolution weights gradient primitive from a C API primitive descriptor that must have a matching kind.
Definition: dnnl.hpp:5347
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const deconvolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a deconvolution weights gradient primitive.
Definition: dnnl.hpp:5334
primitive_desc(const desc &adesc, const engine &aengine, const deconvolution_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a deconvolution weights gradient primitive.
Definition: dnnl.hpp:5314
memory::desc diff_weights_desc() const
Returns a diff weights memory descriptor.
Definition: dnnl.hpp:5355
memory::desc diff_bias_desc() const
Returns the diff bias memory descriptor.
Definition: dnnl.hpp:5363
primitive_desc()=default
Default constructor. Produces an empty object.
Deconvolution weights gradient primitive.
Definition: dnnl.hpp:5113
deconvolution_backward_weights()=default
Default constructor. Produces an empty object.
deconvolution_backward_weights(const primitive_desc &pd)
Constructs a deconvolution weights gradient primitive.
Definition: dnnl.hpp:5374
Descriptor for a deconvolution forward propagation primitive.
Definition: dnnl.hpp:4684
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a dilated deconvolution forward propagation primitive with bias.
Definition: dnnl.hpp:4809
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a deconvolution forward propagation primitive without bias.
Definition: dnnl.hpp:4761
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a deconvolution forward propagation primitive with bias.
Definition: dnnl.hpp:4716
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &dilates, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for a dilated deconvolution forward propagation primitive without bias.
Definition: dnnl.hpp:4857
Primitive descriptor for a deconvolution forward propagation primitive.
Definition: dnnl.hpp:4878
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a deconvolution forward propagation primitive from a C API primitive descriptor that must have a matching kind.
Definition: dnnl.hpp:4919
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:4931
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:4925
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a deconvolution forward propagation primitive.
Definition: dnnl.hpp:4908
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a deconvolution forward propagation primitive.
Definition: dnnl.hpp:4892
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc bias_desc() const
Returns the bias memory descriptor.
Definition: dnnl.hpp:4934
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition: dnnl.hpp:4928
Deconvolution forward propagation primitive.
Definition: dnnl.hpp:4682
deconvolution_forward(const primitive_desc &pd)
Constructs a deconvolution forward propagation primitive.
Definition: dnnl.hpp:4943
deconvolution_forward()=default
Default constructor. Produces an empty object.
Descriptor for an elementwise backward propagation primitive.
Definition: dnnl.hpp:5943
desc(algorithm aalgorithm, const memory::desc &diff_data_desc, const memory::desc &data_desc, float alpha=0, float beta=0)
Constructs a descriptor for an elementwise backward propagation primitive.
Definition: dnnl.hpp:5957
Primitive descriptor for eltwise backward propagation.
Definition: dnnl.hpp:5970
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:6028
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &adesc, const engine &aengine, const eltwise_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an elementwise backward propagation primitive.
Definition: dnnl.hpp:5987
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const eltwise_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an elementwise backward propagation primitive.
Definition: dnnl.hpp:6007
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:6025
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:6031
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an eltwise backward propagation primitive from a C API primitive descriptor that must have a matching kind.
Definition: dnnl.hpp:6020
Elementwise unary operation backward propagation primitive.
Definition: dnnl.hpp:5941
eltwise_backward()=default
Default constructor. Produces an empty object.
eltwise_backward(const primitive_desc &pd)
Constructs an eltwise backward propagation primitive.
Definition: dnnl.hpp:6040
Descriptor for an elementwise forward propagation primitive.
Definition: dnnl.hpp:5850
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &data_desc, float alpha=0, float beta=0)
Constructs a descriptor for an elementwise forward propagation primitive.
Definition: dnnl.hpp:5865
Primitive descriptor for an elementwise forward propagation primitive.
Definition: dnnl.hpp:5878
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for an elementwise forward propagation primitive.
Definition: dnnl.hpp:5908
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:5928
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:5925
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for an elementwise forward propagation primitive.
Definition: dnnl.hpp:5892
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an eltwise forward propagation primitive from a C API primitive descriptor that must have a matching kind.
Definition: dnnl.hpp:5919
Elementwise unary operation forward propagation primitive.
Definition: dnnl.hpp:5848
eltwise_forward(const primitive_desc &pd)
Constructs an eltwise forward propagation primitive.
Definition: dnnl.hpp:5937
eltwise_forward()=default
Default constructor. Produces an empty object.
An execution engine.
Definition: dnnl.hpp:869
static engine query(const primitive_desc &pd)
Returns the engine of a primitive descriptor.
Definition: dnnl.hpp:938
kind
Kinds of engines.
Definition: dnnl.hpp:874
@ gpu
GPU engine.
@ any
An unspecified engine.
@ cpu
CPU engine.
engine(kind akind, size_t index)
Constructs an engine.
Definition: dnnl.hpp:902
engine()=default
Constructs an empty engine.
static size_t get_count(kind akind)
Returns the number of engines of a certain kind.
Definition: dnnl.hpp:893
engine(const handle< dnnl_primitive_desc_t > &pd)
Constructs an engine based on a primitive from the primitive descriptor pd by querying its engine.
Definition: dnnl.hpp:914
kind get_kind() const
Returns the kind of the engine.
Definition: dnnl.hpp:925
oneDNN exception class.
Definition: dnnl.hpp:84
error(dnnl_status_t status, const char *message)
Constructs an instance of an exception class.
Definition: dnnl.hpp:92
static void wrap_c_api(dnnl_status_t status, const char *message)
A convenience function for wrapping calls to C API functions.
Definition: dnnl.hpp:103
const char * what() const noexcept override
Returns the explanatory string.
Definition: dnnl.hpp:96
Descriptor for a GRU backward propagation primitive.
Definition: dnnl.hpp:9001
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc, rnn_flags flags=rnn_flags::undef)
Constructs a descriptor for a GRU backward propagation primitive.
Definition: dnnl.hpp:9048
Primitive descriptor for a GRU backward propagation primitive.
Definition: dnnl.hpp:9082
primitive_desc(const desc &adesc, const engine &aengine, const gru_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a GRU backward propagation primitive.
Definition: dnnl.hpp:9098
memory::desc diff_weights_iter_desc() const
Returns diff weights iteration memory descriptor.
Definition: dnnl.hpp:9184
memory::desc dst_layer_desc() const
Returns destination layer memory descriptor.
Definition: dnnl.hpp:9156
memory::desc weights_layer_desc() const
Returns weights layer memory descriptor.
Definition: dnnl.hpp:9143
memory::desc src_iter_desc() const
Returns source iteration memory descriptor.
Definition: dnnl.hpp:9140
memory::desc diff_bias_desc() const
Returns diff bias memory descriptor.
Definition: dnnl.hpp:9189
memory::desc weights_iter_desc() const
Returns weights iteration memory descriptor.
Definition: dnnl.hpp:9148
memory::desc bias_desc() const
Returns bias memory descriptor.
Definition: dnnl.hpp:9153
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc diff_dst_iter_desc() const
Returns diff destination iteration memory descriptor.
Definition: dnnl.hpp:9199
memory::desc diff_dst_layer_desc() const
Returns diff destination layer memory descriptor.
Definition: dnnl.hpp:9194
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a GRU backward propagation primitive from a C API primitive des...
Definition: dnnl.hpp:9130
memory::desc src_layer_desc() const
Returns source layer memory descriptor.
Definition: dnnl.hpp:9135
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:9164
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const gru_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a GRU backward propagation primitive.
Definition: dnnl.hpp:9117
memory::desc diff_src_layer_desc() const
Returns diff source layer memory descriptor.
Definition: dnnl.hpp:9169
memory::desc diff_src_iter_desc() const
Returns diff source iteration memory descriptor.
Definition: dnnl.hpp:9174
memory::desc diff_weights_layer_desc() const
Returns diff weights layer memory descriptor.
Definition: dnnl.hpp:9179
memory::desc dst_iter_desc() const
Returns destination iteration memory descriptor.
Definition: dnnl.hpp:9161
GRU backward propagation primitive.
Definition: dnnl.hpp:8999
gru_backward()=default
Default constructor. Produces an empty object.
gru_backward(const primitive_desc &pd)
Constructs a GRU backward propagation primitive.
Definition: dnnl.hpp:9210
Descriptor for a GRU forward propagation primitive.
Definition: dnnl.hpp:8852
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, rnn_flags flags=rnn_flags::undef)
Constructs a descriptor for a GRU forward propagation primitive.
Definition: dnnl.hpp:8887
Primitive descriptor for a GRU forward propagation primitive.
Definition: dnnl.hpp:8910
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a GRU forward propagation primitive.
Definition: dnnl.hpp:8923
memory::desc weights_iter_desc() const
Returns weights iteration memory descriptor.
Definition: dnnl.hpp:8968
memory::desc src_layer_desc() const
Returns source layer memory descriptor.
Definition: dnnl.hpp:8955
memory::desc dst_layer_desc() const
Returns destination layer memory descriptor.
Definition: dnnl.hpp:8976
memory::desc weights_layer_desc() const
Returns weights layer memory descriptor.
Definition: dnnl.hpp:8963
memory::desc bias_desc() const
Returns bias memory descriptor.
Definition: dnnl.hpp:8973
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc dst_iter_desc() const
Returns destination iteration memory descriptor.
Definition: dnnl.hpp:8981
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:8984
memory::desc src_iter_desc() const
Returns source iteration memory descriptor.
Definition: dnnl.hpp:8960
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a GRU forward propagation primitive from a C API primitive desc...
Definition: dnnl.hpp:8949
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a GRU forward propagation primitive.
Definition: dnnl.hpp:8938
GRU forward propagation primitive.
Definition: dnnl.hpp:8850
gru_forward(const primitive_desc &pd)
Constructs a GRU forward propagation primitive.
Definition: dnnl.hpp:8995
gru_forward()=default
Default constructor. Produces an empty object.
A class that provides the destructor for a oneDNN C API handle.
Definition: dnnl.hpp:120
oneDNN C API handle wrapper class.
Definition: dnnl.hpp:136
handle(const handle< T, traits > &)=default
Copy constructor.
bool operator==(const handle< T, traits > &other) const
Equality operator.
Definition: dnnl.hpp:210
bool operator!=(const handle &other) const
Inequality operator.
Definition: dnnl.hpp:220
T get(bool allow_empty=false) const
Returns the underlying C API handle.
Definition: dnnl.hpp:185
handle< T, traits > & operator=(const handle< T, traits > &)=default
Assignment operator.
handle()=default
Constructs an empty handle object.
void reset(T t, bool weak=false)
Resets the handle wrapper object to wrap a new C API handle.
Definition: dnnl.hpp:176
handle(T t, bool weak=false)
Constructs a handle wrapper object from a C API handle.
Definition: dnnl.hpp:169
handle(handle< T, traits > &&)=default
Move constructor.
handle< T, traits > & operator=(handle< T, traits > &&)=default
Move assignment operator.
Descriptor for an inner product backward propagation primitive.
Definition: dnnl.hpp:7192
desc(const memory::desc &diff_src_desc, const memory::desc &weights_desc, const memory::desc &diff_dst_desc)
Constructs a descriptor for an inner product backward propagation primitive.
Definition: dnnl.hpp:7205
Primitive descriptor for an inner product backward propagation primitive.
Definition: dnnl.hpp:7218
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:7279
primitive_desc(const desc &adesc, const engine &aengine, const inner_product_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an inner product backward propagation primitive.
Definition: dnnl.hpp:7235
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition: dnnl.hpp:7276
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an inner product backward propagation primitive from a C API pr...
Definition: dnnl.hpp:7268
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const inner_product_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an inner product backward propagation primitive.
Definition: dnnl.hpp:7255
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:7273
primitive_desc()=default
Default constructor. Produces an empty object.
Inner product backward propagation primitive.
Definition: dnnl.hpp:7190
inner_product_backward_data(const primitive_desc &pd)
Constructs an inner product backward propagation primitive.
Definition: dnnl.hpp:7288
inner_product_backward_data()=default
Default constructor. Produces an empty object.
Descriptor for an inner product weights gradient primitive.
Definition: dnnl.hpp:7294
desc(const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_desc)
Constructs a descriptor for an inner product weights update primitive with bias.
Definition: dnnl.hpp:7308
desc(const memory::desc &src_desc, const memory::desc &diff_weights_desc, const memory::desc &diff_dst_desc)
Constructs a descriptor for an inner product weights update primitive without bias.
Definition: dnnl.hpp:7330
Primitive descriptor for an inner product weights gradient primitive.
Definition: dnnl.hpp:7343
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:7398
memory::desc diff_weights_desc() const
Returns a diff weights memory descriptor.
Definition: dnnl.hpp:7401
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:7406
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an inner product weights update primitive from a C API primitiv...
Definition: dnnl.hpp:7393
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const inner_product_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an inner product weights update primitive.
Definition: dnnl.hpp:7380
memory::desc diff_bias_desc() const
Returns the diff bias memory descriptor.
Definition: dnnl.hpp:7409
primitive_desc(const desc &adesc, const engine &aengine, const inner_product_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an inner product weights update primitive.
Definition: dnnl.hpp:7360
Inner product weights gradient primitive.
Definition: dnnl.hpp:7292
inner_product_backward_weights(const primitive_desc &pd)
Constructs an inner product weights gradient primitive.
Definition: dnnl.hpp:7420
inner_product_backward_weights()=default
Default constructor. Produces an empty object.
Descriptor for an inner product forward propagation primitive.
Definition: dnnl.hpp:7067
desc(prop_kind aprop_kind, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc)
Constructs a descriptor for an inner product forward propagation primitive without bias.
Definition: dnnl.hpp:7108
desc(prop_kind aprop_kind, const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc)
Constructs a descriptor for an inner product forward propagation primitive with bias.
Definition: dnnl.hpp:7084
Primitive descriptor for an inner product forward propagation primitive.
Definition: dnnl.hpp:7121
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an inner product forward propagation primitive from a C API pri...
Definition: dnnl.hpp:7162
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:7174
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for an inner product forward propagation primitive.
Definition: dnnl.hpp:7135
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition: dnnl.hpp:7171
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc bias_desc() const
Returns the bias memory descriptor.
Definition: dnnl.hpp:7177
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:7168
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for an inner product forward propagation primitive.
Definition: dnnl.hpp:7151
Inner product forward propagation primitive.
Definition: dnnl.hpp:7065
inner_product_forward(const primitive_desc &pd)
Constructs an inner product forward propagation primitive.
Definition: dnnl.hpp:7186
inner_product_forward()=default
Default constructor. Produces an empty object.
Descriptor for a layer normalization backward propagation primitive.
Definition: dnnl.hpp:6903
desc(prop_kind aprop_kind, const memory::desc &diff_data_desc, const memory::desc &data_desc, float epsilon, normalization_flags flags)
Constructs a descriptor for a layer normalization backward propagation primitive.
Definition: dnnl.hpp:6943
desc(prop_kind aprop_kind, const memory::desc &diff_data_desc, const memory::desc &data_desc, const memory::desc &stat_desc, float epsilon, normalization_flags flags)
Constructs a descriptor for a layer normalization backward propagation primitive.
Definition: dnnl.hpp:6919
Primitive descriptor for a layer normalization backward propagation primitive.
Definition: dnnl.hpp:6957
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:7023
memory::desc mean_desc() const
Returns memory descriptor for mean.
Definition: dnnl.hpp:7034
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:7042
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a layer normalization backward propagation primitive from a C A...
Definition: dnnl.hpp:7007
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:7026
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:7020
memory::desc variance_desc() const
Returns memory descriptor for variance.
Definition: dnnl.hpp:7037
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition: dnnl.hpp:7017
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const layer_normalization_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a layer normalization backward propagation primitive.
Definition: dnnl.hpp:6994
memory::desc diff_weights_desc() const
Returns a diff weights memory descriptor.
Definition: dnnl.hpp:7029
primitive_desc(const desc &adesc, const engine &aengine, const layer_normalization_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a layer normalization backward propagation primitive.
Definition: dnnl.hpp:6974
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:7014
Layer normalization backward propagation primitive.
Definition: dnnl.hpp:6901
layer_normalization_backward(const primitive_desc &pd)
Constructs a layer normalization backward propagation primitive.
Definition: dnnl.hpp:7051
layer_normalization_backward()=default
Default constructor. Produces an empty object.
Descriptor for a layer normalization forward propagation primitive.
Definition: dnnl.hpp:6758
desc(prop_kind aprop_kind, const memory::desc &data_desc, const memory::desc &stat_desc, float epsilon, normalization_flags flags)
Constructs a descriptor for a layer normalization forward propagation primitive.
Definition: dnnl.hpp:6772
desc(prop_kind aprop_kind, const memory::desc &data_desc, float epsilon, normalization_flags flags)
Constructs a descriptor for a layer normalization forward propagation primitive.
Definition: dnnl.hpp:6793
Primitive descriptor for a layer normalization forward propagation primitive.
Definition: dnnl.hpp:6806
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a layer normalization forward propagation primitive.
Definition: dnnl.hpp:6820
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:6857
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:6854
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:6863
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a layer normalization forward propagation primitive from a C AP...
Definition: dnnl.hpp:6847
memory::desc variance_desc() const
Returns memory descriptor for variance.
Definition: dnnl.hpp:6869
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a layer normalization forward propagation primitive.
Definition: dnnl.hpp:6836
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition: dnnl.hpp:6860
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc mean_desc() const
Returns memory descriptor for mean.
Definition: dnnl.hpp:6866
Layer normalization forward propagation primitive.
Definition: dnnl.hpp:6756
layer_normalization_forward()=default
Default constructor. Produces an empty object.
layer_normalization_forward(const primitive_desc &pd)
Constructs a layer normalization forward propagation primitive.
Definition: dnnl.hpp:6897
Descriptor for an LBR GRU backward propagation primitive.
Definition: dnnl.hpp:9368
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc, rnn_flags flags=rnn_flags::undef)
Constructs a descriptor for an LBR GRU backward propagation primitive.
Definition: dnnl.hpp:9416
Primitive descriptor for an LBR GRU backward propagation primitive.
Definition: dnnl.hpp:9450
memory::desc weights_layer_desc() const
Returns weights layer memory descriptor.
Definition: dnnl.hpp:9513
memory::desc diff_weights_layer_desc() const
Returns diff weights layer memory descriptor.
Definition: dnnl.hpp:9549
primitive_desc(const desc &adesc, const engine &aengine, const lbr_gru_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an LBR GRU backward propagation primitive.
Definition: dnnl.hpp:9467
memory::desc diff_dst_iter_desc() const
Returns diff destination iteration memory descriptor.
Definition: dnnl.hpp:9569
memory::desc diff_bias_desc() const
Returns diff bias memory descriptor.
Definition: dnnl.hpp:9559
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const lbr_gru_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an LBR GRU backward propagation primitive.
Definition: dnnl.hpp:9487
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc dst_iter_desc() const
Returns destination iteration memory descriptor.
Definition: dnnl.hpp:9531
memory::desc weights_iter_desc() const
Returns weights iteration memory descriptor.
Definition: dnnl.hpp:9518
memory::desc src_iter_desc() const
Returns source iteration memory descriptor.
Definition: dnnl.hpp:9510
memory::desc diff_src_iter_desc() const
Returns diff source iteration memory descriptor.
Definition: dnnl.hpp:9544
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an LBR GRU backward propagation primitive from a C API primitive...
Definition: dnnl.hpp:9500
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:9534
memory::desc bias_desc() const
Returns bias memory descriptor.
Definition: dnnl.hpp:9523
memory::desc dst_layer_desc() const
Returns destination layer memory descriptor.
Definition: dnnl.hpp:9526
memory::desc src_layer_desc() const
Returns source layer memory descriptor.
Definition: dnnl.hpp:9505
memory::desc diff_weights_iter_desc() const
Returns diff weights iteration memory descriptor.
Definition: dnnl.hpp:9554
memory::desc diff_dst_layer_desc() const
Returns diff destination layer memory descriptor.
Definition: dnnl.hpp:9564
memory::desc diff_src_layer_desc() const
Returns diff source layer memory descriptor.
Definition: dnnl.hpp:9539
LBR GRU backward propagation primitive.
Definition: dnnl.hpp:9366
lbr_gru_backward(const primitive_desc &pd)
Constructs an LBR GRU backward propagation primitive.
Definition: dnnl.hpp:9580
lbr_gru_backward()=default
Default constructor. Produces an empty object.
Descriptor for an LBR GRU forward propagation primitive.
Definition: dnnl.hpp:9216
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, rnn_flags flags=rnn_flags::undef)
Constructs a descriptor for an LBR GRU forward propagation primitive.
Definition: dnnl.hpp:9252
Primitive descriptor for an LBR GRU forward propagation primitive.
Definition: dnnl.hpp:9275
memory::desc dst_iter_desc() const
Returns destination iteration memory descriptor.
Definition: dnnl.hpp:9348
memory::desc src_iter_desc() const
Returns source iteration memory descriptor.
Definition: dnnl.hpp:9327
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for an LBR GRU forward propagation primitive.
Definition: dnnl.hpp:9305
memory::desc dst_layer_desc() const
Returns destination layer memory descriptor.
Definition: dnnl.hpp:9343
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:9351
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an LBR GRU forward propagation primitive from a C API primitive ...
Definition: dnnl.hpp:9316
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for an LBR GRU forward propagation primitive.
Definition: dnnl.hpp:9289
memory::desc bias_desc() const
Returns bias memory descriptor.
Definition: dnnl.hpp:9340
memory::desc src_layer_desc() const
Returns source layer memory descriptor.
Definition: dnnl.hpp:9322
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc weights_iter_desc() const
Returns weights iteration memory descriptor.
Definition: dnnl.hpp:9335
memory::desc weights_layer_desc() const
Returns weights layer memory descriptor.
Definition: dnnl.hpp:9330
LBR GRU forward propagation primitive.
Definition: dnnl.hpp:9214
lbr_gru_forward()=default
Default constructor. Produces an empty object.
lbr_gru_forward(const primitive_desc &pd)
Constructs an LBR GRU forward propagation primitive.
Definition: dnnl.hpp:9362
Descriptor for a logsoftmax backward propagation primitive.
Definition: dnnl.hpp:6350
desc()=default
Default constructor. Produces an empty object.
desc(const memory::desc &diff_data_desc, const memory::desc &data_desc, int logsoftmax_axis)
Constructs a descriptor for a logsoftmax backward propagation primitive.
Definition: dnnl.hpp:6363
Primitive descriptor for a logsoftmax backward propagation primitive.
Definition: dnnl.hpp:6374
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:6433
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:6439
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:6436
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a logsoftmax backward propagation primitive from a C API primit...
Definition: dnnl.hpp:6424
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const logsoftmax_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a logsoftmax backward propagation primitive.
Definition: dnnl.hpp:6411
primitive_desc(const desc &adesc, const engine &aengine, const logsoftmax_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a logsoftmax backward propagation primitive.
Definition: dnnl.hpp:6391
Logsoftmax backward propagation primitive.
Definition: dnnl.hpp:6348
logsoftmax_backward(const primitive_desc &pd)
Constructs a logsoftmax backward propagation primitive.
Definition: dnnl.hpp:6448
logsoftmax_backward()=default
Default constructor. Produces an empty object.
Descriptor for a logsoftmax forward propagation primitive.
Definition: dnnl.hpp:6256
desc(prop_kind aprop_kind, const memory::desc &data_desc, int logsoftmax_axis)
Constructs a descriptor for a logsoftmax forward propagation primitive.
Definition: dnnl.hpp:6270
desc()=default
Default constructor. Produces an empty object.
Primitive descriptor for a logsoftmax forward propagation primitive.
Definition: dnnl.hpp:6281
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:6335
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:6332
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a logsoftmax forward propagation primitive.
Definition: dnnl.hpp:6311
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a logsoftmax forward propagation primitive from a C API primiti...
Definition: dnnl.hpp:6322
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a logsoftmax forward propagation primitive.
Definition: dnnl.hpp:6295
primitive_desc()=default
Default constructor. Produces an empty object.
Logsoftmax forward propagation primitive.
Definition: dnnl.hpp:6254
logsoftmax_forward()=default
Default constructor. Produces an empty object.
logsoftmax_forward(const primitive_desc &pd)
Constructs a logsoftmax forward propagation primitive.
Definition: dnnl.hpp:6344
Descriptor for an LRN backward propagation primitive.
Definition: dnnl.hpp:5486
desc(algorithm aalgorithm, const memory::desc &data_desc, const memory::desc &diff_data_desc, memory::dim local_size, float alpha, float beta, float k=1.f)
Constructs a descriptor for an LRN backward propagation primitive.
Definition: dnnl.hpp:5501
Primitive descriptor for an LRN backward propagation primitive.
Definition: dnnl.hpp:5514
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const lrn_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an LRN backward propagation primitive.
Definition: dnnl.hpp:5549
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an LRN backward propagation primitive from a C API primitive de...
Definition: dnnl.hpp:5562
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:5570
primitive_desc(const desc &adesc, const engine &aengine, const lrn_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an LRN backward propagation primitive.
Definition: dnnl.hpp:5530
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:5573
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:5567
primitive_desc()=default
Default constructor. Produces an empty object.
Local response normalization (LRN) backward propagation primitive.
Definition: dnnl.hpp:5484
lrn_backward(const primitive_desc &pd)
Constructs an LRN backward propagation primitive.
Definition: dnnl.hpp:5582
lrn_backward()=default
Default constructor. Produces an empty object.
Descriptor for an LRN forward propagation primitive.
Definition: dnnl.hpp:5391
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &data_desc, memory::dim local_size, float alpha, float beta, float k=1.f)
Constructs a descriptor for an LRN forward propagation primitive.
Definition: dnnl.hpp:5407
Primitive descriptor for an LRN forward propagation primitive.
Definition: dnnl.hpp:5420
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:5465
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:5468
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for an LRN forward propagation primitive.
Definition: dnnl.hpp:5448
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:5471
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for an LRN forward propagation primitive.
Definition: dnnl.hpp:5433
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an LRN forward propagation primitive from a C API primitive des...
Definition: dnnl.hpp:5459
Local response normalization (LRN) forward propagation primitive.
Definition: dnnl.hpp:5389
lrn_forward()=default
Default constructor. Produces an empty object.
lrn_forward(const primitive_desc &pd)
Constructs an LRN forward propagation primitive.
Definition: dnnl.hpp:5480
Descriptor for an LSTM backward propagation primitive.
Definition: dnnl.hpp:8348
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &src_iter_c_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &weights_peephole_desc, const memory::desc &weights_projection_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &dst_iter_c_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_src_iter_c_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_weights_peephole_desc, const memory::desc &diff_weights_projection_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc, const memory::desc &diff_dst_iter_c_desc, rnn_flags flags=rnn_flags::undef)
Constructs an LSTM (with or without peephole and with or without projection) descriptor for backward ...
Definition: dnnl.hpp:8426
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &src_iter_c_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &dst_iter_c_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_src_iter_c_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc, const memory::desc &diff_dst_iter_c_desc, rnn_flags flags=rnn_flags::undef)
Constructs an LSTM descriptor for backward propagation using prop_kind, direction,...
Definition: dnnl.hpp:8637
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &src_iter_c_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &weights_peephole_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &dst_iter_c_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_src_iter_c_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_weights_peephole_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc, const memory::desc &diff_dst_iter_c_desc, rnn_flags flags=rnn_flags::undef)
Constructs an LSTM (with or without peephole) descriptor for backward propagation using prop_kind,...
Definition: dnnl.hpp:8538
Primitive descriptor for an LSTM backward propagation primitive.
Definition: dnnl.hpp:8678
memory::desc weights_iter_desc() const
Returns weights iteration memory descriptor.
Definition: dnnl.hpp:8749
memory::desc diff_dst_iter_desc() const
Returns diff destination iteration memory descriptor.
Definition: dnnl.hpp:8830
memory::desc diff_weights_projection_desc() const
Returns diff weights projection memory descriptor.
Definition: dnnl.hpp:8815
memory::desc weights_peephole_desc() const
Returns weights peephole memory descriptor.
Definition: dnnl.hpp:8754
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an LSTM backward propagation primitive from a C API primitive d...
Definition: dnnl.hpp:8726
memory::desc diff_weights_peephole_desc() const
Returns diff weights peephole memory descriptor.
Definition: dnnl.hpp:8810
memory::desc dst_iter_c_desc() const
Returns destination recurrent cell state memory descriptor.
Definition: dnnl.hpp:8775
memory::desc src_layer_desc() const
Returns source layer memory descriptor.
Definition: dnnl.hpp:8731
memory::desc dst_iter_desc() const
Returns destination iteration memory descriptor.
Definition: dnnl.hpp:8772
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &adesc, const engine &aengine, const lstm_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an LSTM backward propagation primitive.
Definition: dnnl.hpp:8694
memory::desc diff_src_layer_desc() const
Returns diff source layer memory descriptor.
Definition: dnnl.hpp:8785
memory::desc src_iter_desc() const
Returns source iteration memory descriptor.
Definition: dnnl.hpp:8736
memory::desc diff_weights_iter_desc() const
Returns diff weights iteration memory descriptor.
Definition: dnnl.hpp:8805
memory::desc weights_projection_desc() const
Returns weights projection memory descriptor.
Definition: dnnl.hpp:8759
memory::desc diff_bias_desc() const
Returns diff bias memory descriptor.
Definition: dnnl.hpp:8820
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const lstm_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for an LSTM backward propagation primitive.
Definition: dnnl.hpp:8713
memory::desc bias_desc() const
Returns bias memory descriptor.
Definition: dnnl.hpp:8764
memory::desc src_iter_c_desc() const
Returns source recurrent cell state memory descriptor.
Definition: dnnl.hpp:8739
memory::desc dst_layer_desc() const
Returns destination layer memory descriptor.
Definition: dnnl.hpp:8767
memory::desc diff_dst_iter_c_desc() const
Returns diff destination recurrent cell state memory descriptor.
Definition: dnnl.hpp:8835
memory::desc diff_src_iter_desc() const
Returns diff source iteration memory descriptor.
Definition: dnnl.hpp:8790
memory::desc diff_dst_layer_desc() const
Returns diff destination layer memory descriptor.
Definition: dnnl.hpp:8825
memory::desc weights_layer_desc() const
Returns weights layer memory descriptor.
Definition: dnnl.hpp:8744
memory::desc diff_weights_layer_desc() const
Returns diff weights layer memory descriptor.
Definition: dnnl.hpp:8800
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:8780
memory::desc diff_src_iter_c_desc() const
Returns diff source recurrent cell state memory descriptor.
Definition: dnnl.hpp:8795
LSTM backward propagation primitive.
Definition: dnnl.hpp:8346
lstm_backward()=default
Default constructor. Produces an empty object.
lstm_backward(const primitive_desc &pd)
Constructs an LSTM backward propagation primitive.
Definition: dnnl.hpp:8846
Descriptor for an LSTM forward propagation primitive.
Definition: dnnl.hpp:8031
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &src_iter_c_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &dst_iter_c_desc, rnn_flags flags=rnn_flags::undef)
Constructs a descriptor for an LSTM forward propagation primitive.
Definition: dnnl.hpp:8211
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &src_iter_c_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &weights_peephole_desc, const memory::desc &weights_projection_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &dst_iter_c_desc, rnn_flags flags=rnn_flags::undef)
Constructs a descriptor for an LSTM (with or without peephole and with or without projection) forward...
Definition: dnnl.hpp:8082
desc(prop_kind aprop_kind, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &src_iter_c_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &weights_peephole_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &dst_iter_c_desc, rnn_flags flags=rnn_flags::undef)
Constructs a descriptor for an LSTM (with or without peephole) forward propagation primitive.
Definition: dnnl.hpp:8150
Primitive descriptor for an LSTM forward propagation primitive.
Definition: dnnl.hpp:8237
memory::desc dst_iter_desc() const
Returns destination iteration memory descriptor.
Definition: dnnl.hpp:8323
memory::desc weights_peephole_desc() const
Returns weights peephole memory descriptor.
Definition: dnnl.hpp:8305
memory::desc weights_iter_desc() const
Returns weights iteration memory descriptor.
Definition: dnnl.hpp:8300
memory::desc dst_layer_desc() const
Returns destination layer memory descriptor.
Definition: dnnl.hpp:8318
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:8331
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for an LSTM forward propagation primitive.
Definition: dnnl.hpp:8250
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for an LSTM forward propagation primitive from a C API primitive de...
Definition: dnnl.hpp:8276
memory::desc dst_iter_c_desc() const
Returns destination recurrent cell state memory descriptor.
Definition: dnnl.hpp:8326
memory::desc weights_layer_desc() const
Returns weights layer memory descriptor.
Definition: dnnl.hpp:8295
memory::desc weights_projection_desc() const
Returns weights projection memory descriptor.
Definition: dnnl.hpp:8310
memory::desc src_iter_c_desc() const
Returns source recurrent cell state memory descriptor.
Definition: dnnl.hpp:8290
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for an LSTM forward propagation primitive.
Definition: dnnl.hpp:8265
memory::desc src_iter_desc() const
Returns source iteration memory descriptor.
Definition: dnnl.hpp:8287
memory::desc bias_desc() const
Returns bias memory descriptor.
Definition: dnnl.hpp:8315
memory::desc src_layer_desc() const
Returns source layer memory descriptor.
Definition: dnnl.hpp:8282
LSTM forward propagation primitive.
Definition: dnnl.hpp:8029
lstm_forward(const primitive_desc &pd)
Constructs an LSTM forward propagation primitive.
Definition: dnnl.hpp:8342
lstm_forward()=default
Default constructor. Produces an empty object.
Descriptor for a matmul primitive.
Definition: dnnl.hpp:9856
desc(const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &dst_desc)
Constructs a descriptor for a matmul primitive.
Definition: dnnl.hpp:9864
desc(const memory::desc &src_desc, const memory::desc &weights_desc, const memory::desc &bias_desc, const memory::desc &dst_desc)
Constructs a descriptor for a matmul primitive.
Definition: dnnl.hpp:9878
Primitive descriptor for a matmul primitive.
Definition: dnnl.hpp:9888
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a matmul primitive.
Definition: dnnl.hpp:9914
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition: dnnl.hpp:9930
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a matmul primitive from a C API primitive descriptor that must ...
Definition: dnnl.hpp:9923
memory::desc bias_desc() const
Returns the bias memory descriptor.
Definition: dnnl.hpp:9935
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:9927
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:9940
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a matmul primitive.
Definition: dnnl.hpp:9900
Matrix multiplication (matmul) primitive.
Definition: dnnl.hpp:9854
matmul(const primitive_desc &pd)
Constructs a matmul primitive.
Definition: dnnl.hpp:9948
matmul()=default
Default constructor. Produces an empty object.
A memory descriptor.
Definition: dnnl.hpp:1984
desc(const dims &adims, data_type adata_type, format_tag aformat_tag, bool allow_empty=false)
Constructs a memory descriptor.
Definition: dnnl.hpp:2008
desc()
Constructs a zero (empty) memory descriptor.
Definition: dnnl.hpp:1991
bool operator!=(const desc &other) const
An inequality operator.
Definition: dnnl.hpp:2219
desc permute_axes(const std::vector< int > &permutation, bool allow_empty=false) const
Constructs a memory descriptor by permuting axes in an existing one.
Definition: dnnl.hpp:2170
desc submemory_desc(const dims &adims, const dims &offsets, bool allow_empty=false) const
Constructs a memory descriptor for a region inside an area described by this memory descriptor.
Definition: dnnl.hpp:2066
bool operator==(const desc &other) const
An equality operator.
Definition: dnnl.hpp:2211
bool is_zero() const
Checks whether the memory descriptor is zero (empty).
Definition: dnnl.hpp:2205
memory::dims dims() const
Returns dimensions of the memory descriptor.
Definition: dnnl.hpp:2186
memory::data_type data_type() const
Returns the data type of the memory descriptor.
Definition: dnnl.hpp:2192
desc reshape(const dims &adims, bool allow_empty=false) const
Constructs a memory descriptor by reshaping an existing one.
Definition: dnnl.hpp:2122
desc(const dims &adims, data_type adata_type, const dims &strides, bool allow_empty=false)
Constructs a memory descriptor by strides.
Definition: dnnl.hpp:2036
size_t get_size() const
Returns size of the memory descriptor in bytes.
Definition: dnnl.hpp:2200
desc(const dnnl_memory_desc_t &data)
Constructs a memory descriptor from a C API data structure.
Definition: dnnl.hpp:2053
dnnl_memory_desc_t data
The underlying C API data structure.
Definition: dnnl.hpp:1987
Memory object.
Definition: dnnl.hpp:1108
void unmap_data(void *mapped_ptr) const
Unmaps a memory object and writes back any changes made to the previously mapped memory buffer.
Definition: dnnl.hpp:2385
T * map_data() const
Maps a memory object and returns a host-side pointer to a memory buffer with a copy of its contents.
Definition: dnnl.hpp:2368
static void validate_dims(const std::vector< T > &v, int min_size=0)
Helper function that validates that an std::vector of dimensions can be safely converted to the C API...
Definition: dnnl.hpp:1124
memory()=default
Default constructor.
dnnl_dim_t dim
Integer type for representing dimension sizes and indices.
Definition: dnnl.hpp:1112
memory(const desc &md, const engine &aengine, void *handle)
Constructs a memory object.
Definition: dnnl.hpp:2252
void set_data_handle(void *handle, const stream &astream) const
Sets the underlying memory buffer.
Definition: dnnl.hpp:2324
void * get_data_handle() const
Returns the underlying memory buffer.
Definition: dnnl.hpp:2289
format_tag
Memory format tag specification.
Definition: dnnl.hpp:1205
data_type
Data type specification.
Definition: dnnl.hpp:1130
@ undef
Undefined data type (used for empty memory descriptors).
engine get_engine() const
Returns the associated engine.
Definition: dnnl.hpp:2278
format_kind
Memory format kind.
Definition: dnnl.hpp:1149
memory(const desc &md, const engine &aengine)
Constructs a memory object.
Definition: dnnl.hpp:2266
void set_data_handle(void *handle) const
Sets the underlying memory buffer.
Definition: dnnl.hpp:2340
desc get_desc() const
Returns the associated memory descriptor.
Definition: dnnl.hpp:2270
std::vector< dim > dims
Vector of dimensions.
Definition: dnnl.hpp:1115
Descriptor for a pooling backward propagation primitive.
Definition: dnnl.hpp:5710
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &kernel, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for pooling backward propagation primitive.
Definition: dnnl.hpp:5734
Primitive descriptor for a pooling backward propagation primitive.
Definition: dnnl.hpp:5753
primitive_desc(const desc &adesc, const engine &aengine, const pooling_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a pooling backward propagation primitive.
Definition: dnnl.hpp:5769
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:5809
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:5812
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const pooling_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a pooling backward propagation primitive.
Definition: dnnl.hpp:5788
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:5806
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a pooling backward propagation primitive from a C API primitive...
Definition: dnnl.hpp:5801
Pooling backward propagation primitive.
Definition: dnnl.hpp:5708
pooling_backward()=default
Default constructor. Produces an empty object.
pooling_backward(const primitive_desc &pd)
Constructs a pooling backward propagation primitive.
Definition: dnnl.hpp:5821
Descriptor for a pooling forward propagation primitive.
Definition: dnnl.hpp:5598
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &kernel, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for pooling forward propagation primitive.
Definition: dnnl.hpp:5625
Primitive descriptor for a pooling forward propagation primitive.
Definition: dnnl.hpp:5644
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a pooling forward propagation primitive.
Definition: dnnl.hpp:5672
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:5692
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:5689
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a pooling forward propagation primitive from a C API primitive ...
Definition: dnnl.hpp:5683
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:5695
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a pooling forward propagation primitive.
Definition: dnnl.hpp:5657
Pooling forward propagation primitive.
Definition: dnnl.hpp:5596
pooling_forward(const primitive_desc &pd)
Constructs a pooling forward propagation primitive.
Definition: dnnl.hpp:5704
pooling_forward()=default
Default constructor. Produces an empty object.
Descriptor for a pooling v2 (dilated pooling) backward propagation primitive.
Definition: dnnl.hpp:10355
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &diff_dst_desc, const memory::dims &strides, const memory::dims &kernel, const memory::dims &dilation, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for pooling v2 (dilated pooling) backward propagation primitive.
Definition: dnnl.hpp:10381
Primitive descriptor for a pooling v2 (dilated pooling) backward propagation primitive.
Definition: dnnl.hpp:10402
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:10461
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:10458
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a pooling v2 (dilated pooling) backward propagation primitive f...
Definition: dnnl.hpp:10453
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &adesc, const engine &aengine, const pooling_v2_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a pooling v2 (dilated pooling) backward propagation primitive.
Definition: dnnl.hpp:10419
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const pooling_v2_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a pooling v2 (dilated pooling) backward propagation primitive.
Definition: dnnl.hpp:10439
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:10464
Pooling v2 (dilated pooling) backward propagation primitive.
Definition: dnnl.hpp:10353
pooling_v2_backward(const primitive_desc &pd)
Constructs a pooling v2 (dilated pooling) backward propagation primitive.
Definition: dnnl.hpp:10474
pooling_v2_backward()=default
Default constructor. Produces an empty object.
Descriptor for a pooling v2 (dilated pooling) forward propagation primitive.
Definition: dnnl.hpp:10234
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &dst_desc, const memory::dims &strides, const memory::dims &kernel, const memory::dims &dilation, const memory::dims &padding_l, const memory::dims &padding_r)
Constructs a descriptor for pooling v2 (dilated pooling) forward propagation primitive.
Definition: dnnl.hpp:10263
Primitive descriptor for a pooling v2 (dilated pooling) forward propagation primitive.
Definition: dnnl.hpp:10285
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:10339
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:10336
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:10333
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a pooling v2 (dilated pooling) forward propagation primitive.
Definition: dnnl.hpp:10315
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a pooling v2 (dilated pooling) forward propagation primitive fr...
Definition: dnnl.hpp:10327
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a pooling v2 (dilated pooling) forward propagation primitive.
Definition: dnnl.hpp:10299
Pooling v2 (dilated pooling) forward propagation primitive.
Definition: dnnl.hpp:10232
pooling_v2_forward()=default
Default constructor. Produces an empty object.
pooling_v2_forward(const primitive_desc &pd)
Constructs a pooling v2 (dilated pooling) forward propagation primitive.
Definition: dnnl.hpp:10349
Post-ops.
Definition: dnnl.hpp:2450
void get_params_dw_k3s1p1(int index, memory::data_type &weights_data_type, memory::data_type &bias_data_type, memory::data_type &dst_data_type, int &mask, std::vector< float > &scales) const
Returns the parameters of a depthwise post-op with stride 1.
Definition: dnnl.hpp:2626
void get_params_binary(int index, algorithm &aalgorithm, memory::desc &src1_desc) const
Returns the parameters of a binary post-op.
Definition: dnnl.hpp:2762
void get_params_sum(int index, float &scale, memory::data_type &data_type) const
Returns the parameters of an accumulation (sum) post-op.
Definition: dnnl.hpp:2527
void append_eltwise(float scale, algorithm aalgorithm, float alpha, float beta)
Appends an elementwise post-op.
Definition: dnnl.hpp:2549
void append_binary(algorithm aalgorithm, const memory::desc &src1_desc)
Appends a binary post-op.
Definition: dnnl.hpp:2751
void append_dw_k3s1p1(memory::data_type weights_data_type, memory::data_type bias_data_type, memory::data_type dst_data_type, int mask, const std::vector< float > &scales)
Appends a depthwise post-op convolution with stride 1.
Definition: dnnl.hpp:2600
primitive::kind kind(int index) const
Returns the primitive kind of post-op at entry with a certain index.
Definition: dnnl.hpp:2467
int len() const
Returns the number of post-ops entries.
Definition: dnnl.hpp:2462
void append_dw_k3s2p1(memory::data_type weights_data_type, memory::data_type bias_data_type, memory::data_type dst_data_type, int mask, const std::vector< float > &scales)
Appends a depthwise post-op convolution with stride 2.
Definition: dnnl.hpp:2685
post_ops()
Constructs an empty sequence of post-ops.
Definition: dnnl.hpp:2454
void get_params_dw_k3s2p1(int index, memory::data_type &weights_data_type, memory::data_type &bias_data_type, memory::data_type &dst_data_type, int &mask, std::vector< float > &scales) const
Returns the parameters of a depthwise post-op with stride 2.
Definition: dnnl.hpp:2711
void get_params_eltwise(int index, float &scale, algorithm &aalgorithm, float &alpha, float &beta) const
Returns parameters of an elementwise post-op.
Definition: dnnl.hpp:2563
void get_params_sum(int index, float &scale) const
Returns the parameters of an accumulation (sum) post-op.
Definition: dnnl.hpp:2517
void append_sum(float scale=1.f, memory::data_type data_type=memory::data_type::undef)
Appends an accumulation (sum) post-op.
Definition: dnnl.hpp:2502
Descriptor for a PReLU backward propagation primitive.
Definition: dnnl.hpp:10578
desc(const memory::desc &data_desc, const memory::desc &weight_desc, const memory::desc &diff_data_desc, const memory::desc &diff_weights_desc)
Constructs a descriptor for a PReLU backward propagation primitive.
Definition: dnnl.hpp:10589
Primitive descriptor for a PReLU backward propagation primitive.
Definition: dnnl.hpp:10602
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:10657
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:10660
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const prelu_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a PReLU backward propagation primitive.
Definition: dnnl.hpp:10639
primitive_desc(const desc &adesc, const engine &aengine, const prelu_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a PReLU backward propagation primitive.
Definition: dnnl.hpp:10619
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:10663
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a PReLU backward propagation primitive from a C API primitive d...
Definition: dnnl.hpp:10652
primitive_desc()=default
Default constructor. Produces an empty object.
PReLU backward propagation primitive.
Definition: dnnl.hpp:10576
prelu_backward()=default
Default constructor. Produces an empty object.
prelu_backward(const primitive_desc &pd)
Constructs a PReLU backward propagation primitive.
Definition: dnnl.hpp:10672
Descriptor for a PReLU forward propagation primitive.
Definition: dnnl.hpp:10491
desc(prop_kind aprop_kind, const memory::desc &data_desc, const memory::desc &weight_desc)
Constructs a descriptor for a PReLU forward propagation primitive.
Definition: dnnl.hpp:10502
Primitive descriptor for a PReLU forward propagation primitive.
Definition: dnnl.hpp:10513
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:10563
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:10560
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a PReLU forward propagation primitive.
Definition: dnnl.hpp:10527
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a PReLU forward propagation primitive.
Definition: dnnl.hpp:10543
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a PReLU forward propagation primitive from a C API primitive de...
Definition: dnnl.hpp:10554
PReLU forward propagation primitive.
Definition: dnnl.hpp:10489
prelu_forward(const primitive_desc &pd)
Constructs a PReLU forward propagation primitive.
Definition: dnnl.hpp:10572
prelu_forward()=default
Default constructor. Produces an empty object.
Primitive attributes.
Definition: dnnl.hpp:2786
void get_zero_points(int arg, int &mask, std::vector< int32_t > &zero_points) const
Returns zero points correspondence mask and values.
Definition: dnnl.hpp:2953
const post_ops get_post_ops() const
Returns post-ops previously set via set_post_ops().
Definition: dnnl.hpp:2999
void set_rnn_data_qparams(float scale, float shift)
Sets quantization scale and shift parameters for RNN data tensors.
Definition: dnnl.hpp:3054
void get_rnn_weights_qparams(int &mask, std::vector< float > &scales)
Returns the quantization scaling factors for RNN weights tensors.
Definition: dnnl.hpp:3132
void get_rnn_data_qparams(float &scale, float &shift)
Returns the quantization scale and shift parameters for RNN data tensors.
Definition: dnnl.hpp:3070
void set_output_scales(int mask, const std::vector< float > &scales)
Sets output scaling factors correspondence mask and values.
Definition: dnnl.hpp:2888
void get_rnn_weights_projection_qparams(int &mask, std::vector< float > &scales)
Returns the quantization scaling factors for RNN projection weights tensors.
Definition: dnnl.hpp:3201
void set_rnn_weights_qparams(int mask, const std::vector< float > &scales)
Sets quantization scaling factors for RNN weights tensors.
Definition: dnnl.hpp:3106
void set_rnn_weights_projection_qparams(int mask, const std::vector< float > &scales)
Sets quantization scaling factors for RNN projection weights tensors.
Definition: dnnl.hpp:3173
void set_scratchpad_mode(scratchpad_mode mode)
Sets scratchpad mode.
Definition: dnnl.hpp:2817
void set_scales(int arg, int mask, const std::vector< float > &scales)
Sets scaling factors for primitive operations for a given memory argument.
Definition: dnnl.hpp:2936
void get_scales(int arg, int &mask, std::vector< float > &scales) const
Returns scaling factors correspondence mask and values for a given memory argument.
Definition: dnnl.hpp:2906
void get_output_scales(int &mask, std::vector< float > &scales) const
Returns output scaling factors correspondence mask and values.
Definition: dnnl.hpp:2832
primitive_attr(dnnl_primitive_attr_t attr)
Creates primitive attributes from a C API dnnl_primitive_attr_t handle.
Definition: dnnl.hpp:2802
void set_post_ops(const post_ops ops)
Sets post-ops.
Definition: dnnl.hpp:3016
primitive_attr()
Constructs default (empty) primitive attributes.
Definition: dnnl.hpp:2790
void set_zero_points(int arg, int mask, const std::vector< int32_t > &zero_points)
Sets zero points for primitive operations for a given memory argument.
Definition: dnnl.hpp:2988
scratchpad_mode get_scratchpad_mode() const
Returns the scratchpad mode.
Definition: dnnl.hpp:2806
Base class for all primitive descriptors.
Definition: dnnl.hpp:3225
primitive_attr get_primitive_attr() const
Returns the primitive attributes.
Definition: dnnl.hpp:3409
memory::desc diff_weights_desc(int idx) const
Returns a diff weights memory descriptor.
Definition: dnnl.hpp:3335
primitive_desc_base()=default
Default constructor. Produces an empty object.
engine get_engine() const
Returns the engine of the primitive descriptor.
Definition: dnnl.hpp:3233
memory::desc query_md(query what, int idx=0) const
Returns a memory descriptor.
Definition: dnnl.hpp:3270
memory::desc dst_desc(int idx) const
Returns a destination memory descriptor.
Definition: dnnl.hpp:3299
memory::desc diff_dst_desc(int idx) const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:3326
memory::desc scratchpad_desc() const
Returns the scratchpad memory descriptor.
Definition: dnnl.hpp:3391
void reset_with_clone(const_dnnl_primitive_desc_t pd)
Resets the value of the handle to a clone of a C API primitive descriptor.
Definition: dnnl.hpp:3433
dnnl::primitive::kind get_kind() const
Returns the kind of the primitive descriptor.
Definition: dnnl.hpp:3421
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:3370
memory::desc diff_src_desc(int idx) const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:3317
memory::desc weights_desc() const
Returns a weights memory descriptor.
Definition: dnnl.hpp:3358
primitive_desc_base(dnnl_primitive_desc_t pd, dnnl::primitive::kind prim_kind, dnnl::prop_kind prop_kind1, dnnl::prop_kind prop_kind2)
Constructs a primitive descriptor base object from a clone of a C API primitive descriptor after veri...
Definition: dnnl.hpp:3485
primitive_desc_base(dnnl_primitive_desc_t pd, dnnl::primitive::kind prim_kind)
Constructs a primitive descriptor base object from a clone of a C API primitive descriptor after veri...
Definition: dnnl.hpp:3453
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:3364
memory::desc weights_desc(int idx) const
Returns a weights memory descriptor.
Definition: dnnl.hpp:3308
memory::dim query_s64(query what) const
Returns a memory::dim value (same as int64_t).
Definition: dnnl.hpp:3249
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:3382
engine scratchpad_engine() const
Returns the engine on which the scratchpad memory is located.
Definition: dnnl.hpp:3397
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:3352
const char * impl_info_str() const
Returns the implementation name.
Definition: dnnl.hpp:3237
memory::desc src_desc(int idx) const
Returns a source memory descriptor.
Definition: dnnl.hpp:3290
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:3346
primitive_desc_base(dnnl_primitive_desc_t pd, dnnl::primitive::kind prim_kind, dnnl::prop_kind aprop_kind)
Constructs a primitive descriptor base object from a clone of a C API primitive descriptor after veri...
Definition: dnnl.hpp:3468
memory::desc diff_weights_desc() const
Returns a diff weights memory descriptor.
Definition: dnnl.hpp:3376
A base class for descriptors of all primitives that have an operation descriptor and that support ite...
Definition: dnnl.hpp:3879
primitive_desc(const_dnnl_op_desc_t desc, const primitive_attr *attr, const engine &aengine, const_dnnl_primitive_desc_t hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor.
Definition: dnnl.hpp:3906
bool next_impl()
Advances the primitive iterator to the next implementation.
Definition: dnnl.hpp:3924
Base class for all computational primitives.
Definition: dnnl.hpp:269
void execute(const stream &astream, const std::unordered_map< int, memory > &args) const
Executes computations specified by the primitive in a specified stream.
primitive()=default
Default constructor. Constructs an empty object.
primitive(const primitive_desc &pd)
Constructs a primitive from a primitive descriptor.
kind
Kinds of primitives supported by the library.
Definition: dnnl.hpp:271
@ deconvolution
A deconvolution primitive.
@ pooling_v2
A pooling version 2 primitive.
@ inner_product
An inner product primitive.
@ logsoftmax
A logsoftmax primitive.
@ layer_normalization
A layer normalization primitive.
@ pooling
A pooling primitive.
@ resampling
A resampling primitive.
@ shuffle
A shuffle primitive.
@ rnn
An RNN primitive.
@ batch_normalization
A batch normalization primitive.
@ lrn
An LRN primitive.
@ prelu
A PReLU primitive.
@ eltwise
An element-wise primitive.
@ convolution
A convolution primitive.
@ softmax
A softmax primitive.
@ undef
Undefined primitive.
primitive(const_dnnl_primitive_desc_t c_pd)
Constructs a primitive from a C API primitive descriptor.
Descriptor for reduction.
Definition: dnnl.hpp:10689
desc()=default
Default constructor. Produces an empty object.
desc(algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &dst_desc, float p, float eps)
Constructs a descriptor for a reduction primitive using algorithm-specific parameters,...
Definition: dnnl.hpp:10712
Primitive descriptor for a reduction primitive.
Definition: dnnl.hpp:10722
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:10761
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:10764
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a reduction primitive from a C API primitive descriptor that mu...
Definition: dnnl.hpp:10757
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a reduction primitive.
Definition: dnnl.hpp:10748
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a reduction primitive.
Definition: dnnl.hpp:10734
Reduction.
Definition: dnnl.hpp:10687
reduction(const primitive_desc &pd)
Constructs a reduction primitive.
Definition: dnnl.hpp:10772
reduction()=default
Default constructor. Produces an empty object.
Primitive descriptor for a reorder primitive.
Definition: dnnl.hpp:3549
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:3634
primitive_desc(const engine &src_engine, const memory::desc &src_md, const engine &dst_engine, const memory::desc &dst_md, const primitive_attr &attr=primitive_attr(), bool allow_empty=false)
Constructs a primitive descriptor for a reorder primitive.
Definition: dnnl.hpp:3572
primitive_desc(const memory &src, const memory &dst, const primitive_attr &attr=primitive_attr(), bool allow_empty=false)
Constructs a primitive descriptor for a reorder primitive.
Definition: dnnl.hpp:3598
engine get_src_engine() const
Returns the engine on which the source memory is allocated.
Definition: dnnl.hpp:3623
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a reorder primitive from a C API primitive descriptor which must ...
Definition: dnnl.hpp:3618
engine get_dst_engine() const
Returns the engine on which the destination memory is allocated.
Definition: dnnl.hpp:3629
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:3637
Reorder primitive.
Definition: dnnl.hpp:3547
reorder(const primitive_desc &pd)
Constructs a reorder primitive.
Definition: dnnl.hpp:3645
void execute(const stream &astream, memory &src, memory &dst) const
Executes the reorder primitive.
Definition: dnnl.hpp:3666
reorder()=default
Default constructor. Produces an empty object.
reorder(const memory &src, const memory &dst, const primitive_attr &attr=primitive_attr())
Constructs a reorder primitive that would reorder data between memory objects having the same memory ...
Definition: dnnl.hpp:3654
Descriptor for a resampling backward propagation primitive.
Definition: dnnl.hpp:10110
desc(algorithm aalgorithm, const memory::desc &diff_src_desc, const memory::desc &diff_dst_desc)
Constructs a descriptor for a resampling backward propagation primitive using source and destination ...
Definition: dnnl.hpp:10121
desc(algorithm aalgorithm, const std::vector< float > &factors, const memory::desc &diff_src_desc, const memory::desc &diff_dst_desc)
Constructs a descriptor for a resampling backward propagation primitive.
Definition: dnnl.hpp:10138
Primitive descriptor for a resampling backward propagation primitive.
Definition: dnnl.hpp:10151
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a resampling backward propagation primitive from a C API primit...
Definition: dnnl.hpp:10201
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:10206
primitive_desc(const desc &adesc, const engine &aengine, const resampling_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a resampling backward propagation primitive.
Definition: dnnl.hpp:10168
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:10209
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const resampling_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a resampling backward propagation primitive.
Definition: dnnl.hpp:10188
Resampling backward propagation primitive.
Definition: dnnl.hpp:10108
resampling_backward(const primitive_desc &pd)
Constructs a resampling backward propagation primitive.
Definition: dnnl.hpp:10218
resampling_backward()=default
Default constructor. Produces an empty object.
Descriptor for resampling forward propagation.
Definition: dnnl.hpp:9966
desc(prop_kind aprop_kind, algorithm aalgorithm, const memory::desc &src_desc, const memory::desc &dst_desc)
Constructs a descriptor for a resampling forward propagation primitive using source and destination m...
Definition: dnnl.hpp:9984
desc(prop_kind aprop_kind, algorithm aalgorithm, const std::vector< float > &factors, const memory::desc &src_desc)
Constructs a descriptor for a resampling forward propagation primitive using source memory descriptor...
Definition: dnnl.hpp:10004
desc(prop_kind aprop_kind, algorithm aalgorithm, const std::vector< float > &factors, const memory::desc &src_desc, const memory::desc &dst_desc)
Constructs a descriptor for a resampling forward propagation primitive.
Definition: dnnl.hpp:10031
Primitive descriptor for a resampling forward propagation primitive.
Definition: dnnl.hpp:10045
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a resampling forward propagation primitive.
Definition: dnnl.hpp:10059
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:10095
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:10092
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a resampling forward propagation primitive from a C API primiti...
Definition: dnnl.hpp:10086
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a resampling forward propagation primitive.
Definition: dnnl.hpp:10075
primitive_desc()=default
Default constructor. Produces an empty object.
Resampling forward propagation.
Definition: dnnl.hpp:9964
resampling_forward()=default
Default constructor. Produces an empty object.
resampling_forward(const primitive_desc &pd)
Constructs a resampling forward propagation primitive.
Definition: dnnl.hpp:10104
Base class for primitive descriptors for RNN primitives.
Definition: dnnl.hpp:7434
memory::desc dst_iter_c_desc() const
Returns destination recurrent cell state memory descriptor.
Definition: dnnl.hpp:7519
memory::desc weights_peephole_desc() const
Returns weights peephole memory descriptor.
Definition: dnnl.hpp:7485
memory::desc diff_weights_layer_desc() const
Returns diff weights layer memory descriptor.
Definition: dnnl.hpp:7545
memory::desc weights_layer_desc() const
Returns weights layer memory descriptor.
Definition: dnnl.hpp:7473
memory::desc weights_iter_desc() const
Returns weights iteration memory descriptor.
Definition: dnnl.hpp:7479
memory::desc diff_src_iter_desc() const
Returns diff source iteration memory descriptor.
Definition: dnnl.hpp:7533
memory::desc diff_dst_iter_c_desc() const
Returns diff destination recurrent cell state memory descriptor.
Definition: dnnl.hpp:7593
memory::desc diff_weights_iter_desc() const
Returns diff weights iteration memory descriptor.
Definition: dnnl.hpp:7551
memory::desc diff_dst_iter_desc() const
Returns diff destination iteration memory descriptor.
Definition: dnnl.hpp:7587
rnn_primitive_desc_base()=default
Default constructor. Produces an empty object.
memory::desc diff_src_iter_c_desc() const
Returns diff source recurrent cell state memory descriptor.
Definition: dnnl.hpp:7539
rnn_primitive_desc_base(dnnl_primitive_desc_t pd, dnnl::prop_kind aprop_kind, dnnl::algorithm cell_kind)
Constructs an RNN primitive descriptor base from a C API primitive descriptor while checking that it ...
Definition: dnnl.hpp:7447
memory::desc diff_bias_desc() const
Returns diff bias memory descriptor.
Definition: dnnl.hpp:7573
memory::desc dst_layer_desc() const
Returns destination layer memory descriptor.
Definition: dnnl.hpp:7505
memory::desc diff_weights_projection_desc() const
Returns diff weights projection memory descriptor.
Definition: dnnl.hpp:7564
memory::desc src_iter_c_desc() const
Returns source recurrent cell state memory descriptor.
Definition: dnnl.hpp:7467
memory::desc src_iter_desc() const
Returns source iteration memory descriptor.
Definition: dnnl.hpp:7461
memory::desc bias_desc() const
Returns bias memory descriptor.
Definition: dnnl.hpp:7499
memory::desc weights_projection_desc() const
Returns weights projection memory descriptor.
Definition: dnnl.hpp:7491
memory::desc src_layer_desc() const
Returns source layer memory descriptor.
Definition: dnnl.hpp:7453
memory::desc diff_dst_layer_desc() const
Returns diff destination layer memory descriptor.
Definition: dnnl.hpp:7579
memory::desc dst_iter_desc() const
Returns destination iteration memory descriptor.
Definition: dnnl.hpp:7513
memory::desc diff_weights_peephole_desc() const
Returns diff weights peephole memory descriptor.
Definition: dnnl.hpp:7557
memory::desc diff_src_layer_desc() const
Returns diff source layer memory descriptor.
Definition: dnnl.hpp:7525
Descriptor for a shuffle backward propagation primitive.
Definition: dnnl.hpp:9671
desc(const memory::desc &diff_data_desc, int axis, int group_size)
Constructs a descriptor for a shuffle backward propagation primitive.
Definition: dnnl.hpp:9681
Primitive descriptor for a shuffle backward propagation primitive.
Definition: dnnl.hpp:9690
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a shuffle backward propagation primitive from a C API primitive...
Definition: dnnl.hpp:9721
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:9726
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &adesc, const engine &aengine, const shuffle_forward::primitive_desc &hint_fwd_pd, const primitive_attr &attr=primitive_attr(), bool allow_empty=false)
Constructs a primitive descriptor for a shuffle backward propagation primitive.
Definition: dnnl.hpp:9708
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:9729
Shuffle backward propagation primitive.
Definition: dnnl.hpp:9668
shuffle_backward()=default
Default constructor. Produces an empty object.
shuffle_backward(const primitive_desc &pd)
Constructs a shuffle backward propagation primitive.
Definition: dnnl.hpp:9738
Descriptor for a shuffle forward propagation primitive.
Definition: dnnl.hpp:9596
desc(prop_kind aprop_kind, const memory::desc &data_desc, int axis, int group_size)
Constructs a descriptor for a shuffle forward propagation primitive.
Definition: dnnl.hpp:9608
Primitive descriptor for a shuffle forward propagation primitive.
Definition: dnnl.hpp:9619
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:9655
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:9652
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a shuffle forward propagation primitive from a C API primitive ...
Definition: dnnl.hpp:9646
primitive_desc(const desc &adesc, const engine &aengine, const primitive_attr &attr=primitive_attr(), bool allow_empty=false)
Constructs a primitive descriptor for a shuffle forward propagation primitive.
Definition: dnnl.hpp:9634
primitive_desc()=default
Default constructor. Produces an empty object.
Shuffle forward propagation primitive.
Definition: dnnl.hpp:9594
shuffle_forward()=default
Default constructor. Produces an empty object.
shuffle_forward(const primitive_desc &pd)
Constructs a shuffle forward propagation primitive.
Definition: dnnl.hpp:9664
Descriptor for a softmax backward propagation primitive.
Definition: dnnl.hpp:6146
desc(const memory::desc &diff_data_desc, const memory::desc &data_desc, int softmax_axis)
Constructs a descriptor for a softmax backward propagation primitive.
Definition: dnnl.hpp:6159
desc()=default
Default constructor. Produces an empty object.
Primitive descriptor for a softmax backward propagation primitive.
Definition: dnnl.hpp:6170
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a softmax backward propagation primitive from a C API primitive...
Definition: dnnl.hpp:6220
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const softmax_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a softmax backward propagation primitive.
Definition: dnnl.hpp:6207
memory::desc diff_dst_desc() const
Returns a diff destination memory descriptor.
Definition: dnnl.hpp:6231
memory::desc diff_src_desc() const
Returns a diff source memory descriptor.
Definition: dnnl.hpp:6228
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &adesc, const engine &aengine, const softmax_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a softmax backward propagation primitive.
Definition: dnnl.hpp:6187
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:6225
Softmax backward propagation primitive.
Definition: dnnl.hpp:6144
softmax_backward()=default
Default constructor. Produces an empty object.
softmax_backward(const primitive_desc &pd)
Constructs a softmax backward propagation primitive.
Definition: dnnl.hpp:6240
Descriptor for a softmax forward propagation primitive.
Definition: dnnl.hpp:6056
desc(prop_kind aprop_kind, const memory::desc &data_desc, int softmax_axis)
Constructs a descriptor for a softmax forward propagation primitive.
Definition: dnnl.hpp:6070
desc()=default
Default constructor. Produces an empty object.
Primitive descriptor for a softmax forward propagation primitive.
Definition: dnnl.hpp:6081
memory::desc src_desc() const
Returns a source memory descriptor.
Definition: dnnl.hpp:6128
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a softmax forward propagation primitive.
Definition: dnnl.hpp:6095
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:6131
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a softmax forward propagation primitive from a C API primitive ...
Definition: dnnl.hpp:6122
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a softmax forward propagation primitive.
Definition: dnnl.hpp:6111
primitive_desc()=default
Default constructor. Produces an empty object.
Softmax forward propagation primitive.
Definition: dnnl.hpp:6054
softmax_forward()=default
Default constructor. Produces an empty object.
softmax_forward(const primitive_desc &pd)
Constructs a softmax forward propagation primitive.
Definition: dnnl.hpp:6140
An execution stream.
Definition: dnnl.hpp:985
engine get_engine() const
Returns the associated engine.
Definition: dnnl.hpp:1016
stream & wait()
Waits for all primitives executing in the stream to finish.
Definition: dnnl.hpp:1025
stream(const engine &aengine, flags aflags=flags::default_flags)
Constructs a stream for the specified engine and with behavior controlled by the specified flags.
Definition: dnnl.hpp:1007
flags
Stream flags. Can be combined using the bitwise OR operator.
Definition: dnnl.hpp:989
@ out_of_order
Out-of-order execution.
@ default_flags
Default stream configuration.
@ in_order
In-order execution.
stream()=default
Constructs an empty stream.
Primitive descriptor for a sum primitive.
Definition: dnnl.hpp:3788
memory::desc dst_desc() const
Returns a destination memory descriptor.
Definition: dnnl.hpp:3861
primitive_desc()=default
Default constructor. Produces an empty object.
memory::desc src_desc(int idx=0) const
Returns a source memory descriptor.
Definition: dnnl.hpp:3858
primitive_desc(const memory::desc &dst, const std::vector< float > &scales, const std::vector< memory::desc > &srcs, const engine &aengine, const primitive_attr &attr=primitive_attr())
Constructs a primitive descriptor for a sum primitive.
Definition: dnnl.hpp:3802
primitive_desc(const std::vector< float > &scales, const std::vector< memory::desc > &srcs, const engine &aengine, const primitive_attr &attr=primitive_attr())
Constructs a primitive descriptor for a sum primitive.
Definition: dnnl.hpp:3832
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a sum primitive from a C API primitive descriptor which must have...
Definition: dnnl.hpp:3854
Out-of-place summation (sum) primitive.
Definition: dnnl.hpp:3786
sum()=default
Default constructor. Produces an empty object.
sum(const primitive_desc &pd)
Constructs a sum primitive.
Definition: dnnl.hpp:3869
Descriptor for a vanilla RNN backward propagation primitive.
Definition: dnnl.hpp:7804
desc(prop_kind aprop_kind, algorithm activation, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, const memory::desc &diff_src_layer_desc, const memory::desc &diff_src_iter_desc, const memory::desc &diff_weights_layer_desc, const memory::desc &diff_weights_iter_desc, const memory::desc &diff_bias_desc, const memory::desc &diff_dst_layer_desc, const memory::desc &diff_dst_iter_desc, rnn_flags flags=rnn_flags::undef, float alpha=0.0f, float beta=0.0f)
Constructs a descriptor for a vanilla RNN backward propagation primitive.
Definition: dnnl.hpp:7859
Primitive descriptor for an RNN backward propagation primitive.
Definition: dnnl.hpp:7895
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, const vanilla_rnn_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a vanilla RNN backward propagation primitive.
Definition: dnnl.hpp:7932
memory::desc src_iter_desc() const
Returns source iteration memory descriptor.
Definition: dnnl.hpp:7955
memory::desc diff_dst_layer_desc() const
Returns diff destination layer memory descriptor.
Definition: dnnl.hpp:8009
memory::desc dst_layer_desc() const
Returns destination layer memory descriptor.
Definition: dnnl.hpp:7971
memory::desc diff_src_iter_desc() const
Returns diff source iteration memory descriptor.
Definition: dnnl.hpp:7989
memory::desc diff_weights_iter_desc() const
Returns diff weights iteration memory descriptor.
Definition: dnnl.hpp:7999
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &adesc, const engine &aengine, const vanilla_rnn_forward::primitive_desc &hint_fwd_pd, bool allow_empty=false)
Constructs a primitive descriptor for a vanilla RNN backward propagation primitive.
Definition: dnnl.hpp:7912
memory::desc diff_bias_desc() const
Returns diff bias memory descriptor.
Definition: dnnl.hpp:8004
memory::desc weights_iter_desc() const
Returns weights iteration memory descriptor.
Definition: dnnl.hpp:7963
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a vanilla RNN backward propagation primitive from a C API primi...
Definition: dnnl.hpp:7945
memory::desc weights_layer_desc() const
Returns weights layer memory descriptor.
Definition: dnnl.hpp:7958
memory::desc bias_desc() const
Returns bias memory descriptor.
Definition: dnnl.hpp:7968
memory::desc dst_iter_desc() const
Returns destination iteration memory descriptor.
Definition: dnnl.hpp:7976
memory::desc diff_dst_iter_desc() const
Returns diff destination iteration memory descriptor.
Definition: dnnl.hpp:8014
memory::desc diff_src_layer_desc() const
Returns diff source layer memory descriptor.
Definition: dnnl.hpp:7984
memory::desc src_layer_desc() const
Returns source layer memory descriptor.
Definition: dnnl.hpp:7950
memory::desc diff_weights_layer_desc() const
Returns diff weights layer memory descriptor.
Definition: dnnl.hpp:7994
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:7979
Vanilla RNN backward propagation primitive.
Definition: dnnl.hpp:7802
vanilla_rnn_backward(const primitive_desc &pd)
Constructs a vanilla RNN backward propagation primitive.
Definition: dnnl.hpp:8025
vanilla_rnn_backward()=default
Default constructor. Produces an empty object.
Descriptor for a vanilla RNN forward propagation primitive.
Definition: dnnl.hpp:7643
desc(prop_kind aprop_kind, algorithm activation, rnn_direction direction, const memory::desc &src_layer_desc, const memory::desc &src_iter_desc, const memory::desc &weights_layer_desc, const memory::desc &weights_iter_desc, const memory::desc &bias_desc, const memory::desc &dst_layer_desc, const memory::desc &dst_iter_desc, rnn_flags flags=rnn_flags::undef, float alpha=0.0f, float beta=0.0f)
Constructs a descriptor for a vanilla RNN forward propagation primitive.
Definition: dnnl.hpp:7686
Primitive descriptor for a vanilla RNN forward propagation primitive.
Definition: dnnl.hpp:7711
primitive_desc(const desc &adesc, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a vanilla RNN forward propagation primitive.
Definition: dnnl.hpp:7725
primitive_desc(dnnl_primitive_desc_t pd)
Constructs a primitive descriptor for a vanilla RNN forward propagation primitive from a C API primit...
Definition: dnnl.hpp:7752
memory::desc src_layer_desc() const
Returns source layer memory descriptor.
Definition: dnnl.hpp:7758
memory::desc src_iter_desc() const
Returns source iteration memory descriptor.
Definition: dnnl.hpp:7763
memory::desc weights_iter_desc() const
Returns weights iteration memory descriptor.
Definition: dnnl.hpp:7771
memory::desc weights_layer_desc() const
Returns weights layer memory descriptor.
Definition: dnnl.hpp:7766
memory::desc workspace_desc() const
Returns the workspace memory descriptor.
Definition: dnnl.hpp:7787
memory::desc dst_iter_desc() const
Returns destination iteration memory descriptor.
Definition: dnnl.hpp:7784
primitive_desc()=default
Default constructor. Produces an empty object.
primitive_desc(const desc &adesc, const primitive_attr &attr, const engine &aengine, bool allow_empty=false)
Constructs a primitive descriptor for a vanilla RNN forward propagation primitive.
Definition: dnnl.hpp:7741
memory::desc dst_layer_desc() const
Returns destination layer memory descriptor.
Definition: dnnl.hpp:7779
memory::desc bias_desc() const
Returns bias memory descriptor.
Definition: dnnl.hpp:7776
Vanilla RNN forward propagation primitive.
Definition: dnnl.hpp:7641
vanilla_rnn_forward()=default
Default constructor. Produces an empty object.
vanilla_rnn_forward(const primitive_desc &pd)
Constructs a vanilla RNN forward propagation primitive.
Definition: dnnl.hpp:7798
A descriptor of a Batch Normalization operation.
Definition: dnnl_types.h:1827
A descriptor of a binary operation.
Definition: dnnl_types.h:2035
A descriptor of a convolution operation.
Definition: dnnl_types.h:1534
A descriptor of an element-wise operation.
Definition: dnnl_types.h:1609
An opaque structure to describe an engine.
A descriptor of an inner product operation.
Definition: dnnl_types.h:1897
A descriptor of a Layer Normalization operation.
Definition: dnnl_types.h:1860
A descriptor of a Local Response Normalization (LRN) operation.
Definition: dnnl_types.h:1796
A descriptor of a matrix multiplication operation.
Definition: dnnl_types.h:2061
Memory descriptor.
Definition: dnnl_types.h:1445
dnnl_data_type_t data_type
Data type of the tensor elements.
Definition: dnnl_types.h:1465
dnnl_dims_t dims
Dimensions in the following order:
Definition: dnnl_types.h:1462
int ndims
Number of dimensions.
Definition: dnnl_types.h:1447
An opaque structure to describe a memory.
A descriptor of a pooling operation.
Definition: dnnl_types.h:1696
A descriptor of a pooling operation (version 2).
Definition: dnnl_types.h:1734
An opaque structure for a chain of post operations.
An opaque structure for primitive descriptor attributes.
An opaque structure to describe a primitive descriptor iterator.
An opaque structure to describe a primitive descriptor.
An opaque structure to describe a primitive.
A descriptor of a reduction operation.
Definition: dnnl_types.h:2111
A descriptor of a resampling operation.
Definition: dnnl_types.h:2083
A descriptor for an RNN operation.
Definition: dnnl_types.h:1953
A descriptor of a shuffle operation.
Definition: dnnl_types.h:1587
A descriptor of a Softmax operation.
Definition: dnnl_types.h:1666
An opaque structure to describe an execution stream.
Structure containing version information as per Semantic Versioning.
Definition: dnnl_types.h:2634