Halide 17.0.2
Halide compiler and libraries
Loading...
Searching...
No Matches
Simplify_Internal.h
Go to the documentation of this file.
1#ifndef HALIDE_SIMPLIFY_VISITORS_H
2#define HALIDE_SIMPLIFY_VISITORS_H
3
4/** \file
5 * The simplifier is separated into multiple compilation units with
6 * this single shared header to speed up the build. This file is not
7 * exported in Halide.h. */
8
9#include "Bounds.h"
10#include "IRMatch.h"
11#include "IRVisitor.h"
12#include "Scope.h"
13
14// Because this file is only included by the simplify methods and
15// doesn't go into Halide.h, we're free to use any old names for our
16// macros.
17
18#define LOG_EXPR_MUTATIONS 0
19#define LOG_STMT_MUTATIONS 0
20
// Some visitors build expression templates that create many temporary
// objects while they are constructed and matched. On older compilers
// those temporaries inflate the stack frame of the (recursive) visit
// method. Evaluating such an expression inside an immediately-invoked,
// never-inlined lambda moves the temporaries into the lambda's own frame,
// keeping them off the recursive path.
#define EVAL_IN_LAMBDA(x) ([&]() HALIDE_NEVER_INLINE { return (x); }())
27
28namespace Halide {
29namespace Internal {
30
32 int64_t result;
33 if (mul_with_overflow(64, a, b, &result)) {
34 return result;
35 } else if ((a > 0) == (b > 0)) {
36 return INT64_MAX;
37 } else {
38 return INT64_MIN;
39 }
40}
41
42class Simplify : public VariadicVisitor<Simplify, Expr, Stmt> {
44
45public:
47
48 struct ExprInfo {
49 // We track constant integer bounds when they exist
50 // TODO: Use ConstantInterval?
51 int64_t min = 0, max = 0;
52 bool min_defined = false, max_defined = false;
53 // And the alignment of integer variables
55
88
89 // Mix in existing knowledge about this Expr
90 void intersect(const ExprInfo &other) {
91 if (min_defined && other.min_defined) {
92 min = std::max(min, other.min);
93 } else if (other.min_defined) {
94 min_defined = true;
95 min = other.min;
96 }
97
98 if (max_defined && other.max_defined) {
99 max = std::min(max, other.max);
100 } else if (other.max_defined) {
101 max_defined = true;
102 max = other.max;
103 }
104
106
108 }
109 };
110
113 if (b) {
114 *b = ExprInfo{};
115 }
116 }
117
118#if (LOG_EXPR_MUTATIONS || LOG_STMT_MUTATIONS)
119 static int debug_indent;
120#endif
121
122#if LOG_EXPR_MUTATIONS
123 Expr mutate(const Expr &e, ExprInfo *b) {
124 const std::string spaces(debug_indent, ' ');
125 debug(1) << spaces << "Simplifying Expr: " << e << "\n";
126 debug_indent++;
127 Expr new_e = Super::dispatch(e, b);
128 debug_indent--;
129 if (!new_e.same_as(e)) {
130 debug(1)
131 << spaces << "Before: " << e << "\n"
132 << spaces << "After: " << new_e << "\n";
133 }
134 internal_assert(e.type() == new_e.type());
135 return new_e;
136 }
137
138#else
140 Expr mutate(const Expr &e, ExprInfo *b) {
141 // This gets inlined into every call to mutate, so do not add any code here.
142 return Super::dispatch(e, b);
143 }
144#endif
145
146#if LOG_STMT_MUTATIONS
147 Stmt mutate(const Stmt &s) {
148 const std::string spaces(debug_indent, ' ');
149 debug(1) << spaces << "Simplifying Stmt: " << s << "\n";
150 debug_indent++;
152 debug_indent--;
153 if (!new_s.same_as(s)) {
154 debug(1)
155 << spaces << "Before: " << s << "\n"
156 << spaces << "After: " << new_s << "\n";
157 }
158 return new_s;
159 }
160#else
161 Stmt mutate(const Stmt &s) {
162 return Super::dispatch(s);
163 }
164#endif
165
167 bool no_float_simplify = false;
168
170 bool may_simplify(const Type &t) const {
171 return !no_float_simplify || !t.is_float();
172 }
173
174 // Returns true iff t is an integral type where overflow is undefined
177 return t.is_int() && t.bits() >= 32;
178 }
179
182 return t.is_scalar() && no_overflow_int(t);
183 }
184
185 // Returns true iff t does not have a well defined overflow behavior.
188 return t.is_float() || no_overflow_int(t);
189 }
190
195
196 // Tracked for all let vars
198
199 // Only tracked for integer let vars
201
202 // Symbols used by rewrite rules
215
216 // Tracks whether or not we're inside a vector loop. Certain
217 // transformations are not a good idea if the code is to be
218 // vectorized.
219 bool in_vector_loop = false;
220
221 // Tracks whether or not the current IR is unconditionally unreachable.
222 bool in_unreachable = false;
223
224 // If we encounter a reference to a buffer (a Load, Store, Call,
225 // or Provide), there's an implicit dependence on some associated
226 // symbols.
227 void found_buffer_reference(const std::string &name, size_t dimensions = 0);
228
229 // Wrappers for as_const_foo that are more convenient to use in
230 // the large chains of conditions in the visit methods below.
231 bool const_float(const Expr &e, double *f);
232 bool const_int(const Expr &e, int64_t *i);
233 bool const_uint(const Expr &e, uint64_t *u);
234
235 // Put the args to a commutative op in a canonical order
237 bool should_commute(const Expr &a, const Expr &b) {
238 if (a.node_type() < b.node_type()) {
239 return true;
240 }
241 if (a.node_type() > b.node_type()) {
242 return false;
243 }
244
245 if (a.node_type() == IRNodeType::Variable) {
246 const Variable *va = a.as<Variable>();
247 const Variable *vb = b.as<Variable>();
248 return va->name.compare(vb->name) > 0;
249 }
250
251 return false;
252 }
253
254 std::set<Expr, IRDeepCompare> truths, falsehoods;
255
256 struct ScopedFact {
258
259 std::vector<const Variable *> pop_list;
260 std::vector<const Variable *> bounds_pop_list;
261 std::vector<Expr> truths, falsehoods;
262
263 void learn_false(const Expr &fact);
264 void learn_true(const Expr &fact);
267
268 // Replace exprs known to be truths or falsehoods with const_true or const_false.
271
273 : simplify(s) {
274 }
276
277 // allow move but not copy
278 ScopedFact(const ScopedFact &that) = delete;
280 };
281
282 // Tell the simplifier to learn from and exploit a boolean
283 // condition, over the lifetime of the returned object.
285 ScopedFact f(this);
286 f.learn_true(fact);
287 return f;
288 }
289
290 // Tell the simplifier to assume a boolean condition is false over
291 // the lifetime of the returned object.
293 ScopedFact f(this);
294 f.learn_false(fact);
295 return f;
296 }
297
299 return mutate(s);
300 }
301 Expr mutate_let_body(const Expr &e, ExprInfo *bounds) {
302 return mutate(e, bounds);
303 }
304
305 template<typename T, typename Body>
306 Body simplify_let(const T *op, ExprInfo *bounds);
307
308 Expr visit(const IntImm *op, ExprInfo *bounds);
309 Expr visit(const UIntImm *op, ExprInfo *bounds);
310 Expr visit(const FloatImm *op, ExprInfo *bounds);
311 Expr visit(const StringImm *op, ExprInfo *bounds);
312 Expr visit(const Broadcast *op, ExprInfo *bounds);
313 Expr visit(const Cast *op, ExprInfo *bounds);
314 Expr visit(const Reinterpret *op, ExprInfo *bounds);
315 Expr visit(const Variable *op, ExprInfo *bounds);
316 Expr visit(const Add *op, ExprInfo *bounds);
317 Expr visit(const Sub *op, ExprInfo *bounds);
318 Expr visit(const Mul *op, ExprInfo *bounds);
319 Expr visit(const Div *op, ExprInfo *bounds);
320 Expr visit(const Mod *op, ExprInfo *bounds);
321 Expr visit(const Min *op, ExprInfo *bounds);
322 Expr visit(const Max *op, ExprInfo *bounds);
323 Expr visit(const EQ *op, ExprInfo *bounds);
324 Expr visit(const NE *op, ExprInfo *bounds);
325 Expr visit(const LT *op, ExprInfo *bounds);
326 Expr visit(const LE *op, ExprInfo *bounds);
327 Expr visit(const GT *op, ExprInfo *bounds);
328 Expr visit(const GE *op, ExprInfo *bounds);
329 Expr visit(const And *op, ExprInfo *bounds);
330 Expr visit(const Or *op, ExprInfo *bounds);
331 Expr visit(const Not *op, ExprInfo *bounds);
332 Expr visit(const Select *op, ExprInfo *bounds);
333 Expr visit(const Ramp *op, ExprInfo *bounds);
335 Expr visit(const Load *op, ExprInfo *bounds);
336 Expr visit(const Call *op, ExprInfo *bounds);
337 Expr visit(const Shuffle *op, ExprInfo *bounds);
338 Expr visit(const VectorReduce *op, ExprInfo *bounds);
339 Expr visit(const Let *op, ExprInfo *bounds);
340 Stmt visit(const LetStmt *op);
342 Stmt visit(const For *op);
343 Stmt visit(const Provide *op);
344 Stmt visit(const Store *op);
345 Stmt visit(const Allocate *op);
346 Stmt visit(const Evaluate *op);
348 Stmt visit(const Block *op);
349 Stmt visit(const Realize *op);
350 Stmt visit(const Prefetch *op);
351 Stmt visit(const Free *op);
352 Stmt visit(const Acquire *op);
353 Stmt visit(const Fork *op);
354 Stmt visit(const Atomic *op);
356
357 std::pair<std::vector<Expr>, bool> mutate_with_changes(const std::vector<Expr> &old_exprs, ExprInfo *bounds);
358};
359
360} // namespace Internal
361} // namespace Halide
362
363#endif
Methods for computing the upper and lower bounds of an expression, and the regions of a function read...
#define internal_assert(c)
Definition Errors.h:19
#define HALIDE_ALWAYS_INLINE
Defines a method to match a fragment of IR against a pattern containing wildcards.
Defines the base class for things that recursively walk over the IR.
Defines the Scope class, which is used for keeping track of names in a scope while traversing IR.
A common pattern when traversing Halide IR is that you need to keep track of stuff when you find a Le...
Definition Scope.h:94
Stmt visit(const HoistedStorage *op)
Expr visit(const Min *op, ExprInfo *bounds)
Stmt visit(const ProducerConsumer *op)
HALIDE_ALWAYS_INLINE Expr mutate(const Expr &e, ExprInfo *b)
Scope< ExprInfo > bounds_and_alignment_info
bool const_uint(const Expr &e, uint64_t *u)
IRMatcher::WildConst< 5 > c5
void found_buffer_reference(const std::string &name, size_t dimensions=0)
Expr visit(const Cast *op, ExprInfo *bounds)
Expr visit(const LT *op, ExprInfo *bounds)
Stmt visit(const Block *op)
Expr visit(const VectorReduce *op, ExprInfo *bounds)
Stmt visit(const AssertStmt *op)
Expr visit(const UIntImm *op, ExprInfo *bounds)
HALIDE_ALWAYS_INLINE void clear_bounds_info(ExprInfo *b)
Expr visit(const Load *op, ExprInfo *bounds)
Stmt visit(const Evaluate *op)
Expr visit(const Not *op, ExprInfo *bounds)
Body simplify_let(const T *op, ExprInfo *bounds)
Simplify(bool r, const Scope< Interval > *bi, const Scope< ModulusRemainder > *ai)
Stmt visit(const Prefetch *op)
HALIDE_ALWAYS_INLINE bool no_overflow(Type t)
Expr visit(const Div *op, ExprInfo *bounds)
IRMatcher::WildConst< 1 > c1
std::pair< std::vector< Expr >, bool > mutate_with_changes(const std::vector< Expr > &old_exprs, ExprInfo *bounds)
Expr visit(const Let *op, ExprInfo *bounds)
Expr visit(const Reinterpret *op, ExprInfo *bounds)
Expr visit(const And *op, ExprInfo *bounds)
Stmt visit(const IfThenElse *op)
Expr visit(const NE *op, ExprInfo *bounds)
Expr visit(const FloatImm *op, ExprInfo *bounds)
Expr visit(const Shuffle *op, ExprInfo *bounds)
Expr visit(const Add *op, ExprInfo *bounds)
IRMatcher::WildConst< 0 > c0
ScopedFact scoped_truth(const Expr &fact)
IRMatcher::WildConst< 3 > c3
IRMatcher::WildConst< 2 > c2
Expr visit(const Ramp *op, ExprInfo *bounds)
Expr visit(const IntImm *op, ExprInfo *bounds)
Expr visit(const Max *op, ExprInfo *bounds)
Expr visit(const Variable *op, ExprInfo *bounds)
HALIDE_ALWAYS_INLINE bool may_simplify(const Type &t) const
Stmt visit(const For *op)
Stmt visit(const Atomic *op)
bool const_float(const Expr &e, double *f)
Expr visit(const GT *op, ExprInfo *bounds)
Stmt visit(const Provide *op)
Expr visit(const Sub *op, ExprInfo *bounds)
Expr visit(const LE *op, ExprInfo *bounds)
Expr visit(const Call *op, ExprInfo *bounds)
Stmt mutate_let_body(const Stmt &s, ExprInfo *)
Stmt visit(const Acquire *op)
Expr visit(const Broadcast *op, ExprInfo *bounds)
Stmt visit(const Fork *op)
HALIDE_ALWAYS_INLINE bool no_overflow_int(Type t)
Expr visit(const StringImm *op, ExprInfo *bounds)
std::set< Expr, IRDeepCompare > truths
Expr visit(const Select *op, ExprInfo *bounds)
ScopedFact scoped_falsehood(const Expr &fact)
HALIDE_ALWAYS_INLINE bool should_commute(const Expr &a, const Expr &b)
Expr visit(const Or *op, ExprInfo *bounds)
Expr visit(const Mul *op, ExprInfo *bounds)
Expr mutate_let_body(const Expr &e, ExprInfo *bounds)
Stmt visit(const Store *op)
Expr visit(const Mod *op, ExprInfo *bounds)
HALIDE_ALWAYS_INLINE bool no_overflow_scalar_int(Type t)
bool const_int(const Expr &e, int64_t *i)
Stmt visit(const Free *op)
IRMatcher::WildConst< 4 > c4
Stmt visit(const Allocate *op)
Stmt visit(const Realize *op)
Stmt visit(const LetStmt *op)
std::set< Expr, IRDeepCompare > falsehoods
Expr visit(const EQ *op, ExprInfo *bounds)
Expr visit(const GE *op, ExprInfo *bounds)
A visitor/mutator capable of passing arbitrary arguments to the visit methods using CRTP and returnin...
Definition IRVisitor.h:161
HALIDE_ALWAYS_INLINE Stmt dispatch(const Stmt &s, Args &&...args)
Definition IRVisitor.h:335
For optional debugging during codegen, use the debug class as follows:
Definition Debug.h:49
T mod_imp(T a, T b)
Implementations of division and mod that are specific to Halide.
Definition IROperator.h:239
HALIDE_MUST_USE_RESULT bool add_with_overflow(int bits, int64_t a, int64_t b, int64_t *result)
Routines to perform arithmetic on signed types without triggering signed overflow.
HALIDE_MUST_USE_RESULT bool mul_with_overflow(int bits, int64_t a, int64_t b, int64_t *result)
HALIDE_MUST_USE_RESULT bool sub_with_overflow(int bits, int64_t a, int64_t b, int64_t *result)
int64_t saturating_mul(int64_t a, int64_t b)
This file defines the class FunctionDAG, which is our representation of a Halide pipeline,...
Expr cast(Expr a)
Cast an expression to the halide type corresponding to the C++ type T.
Definition IROperator.h:364
unsigned __INT64_TYPE__ uint64_t
signed __INT64_TYPE__ int64_t
A fragment of Halide syntax.
Definition Expr.h:258
HALIDE_ALWAYS_INLINE Type type() const
Get the type of this expression node.
Definition Expr.h:322
The sum of two expressions.
Definition IR.h:56
Allocate a scratch area called with the given name, type, and size.
Definition IR.h:371
Logical and - are both expressions true.
Definition IR.h:175
If the 'condition' is false, then evaluate and return the message, which should be a call to an error...
Definition IR.h:294
Lock all the Store nodes in the body statement.
Definition IR.h:948
A sequence of statements to be executed in-order.
Definition IR.h:442
A vector with 'lanes' elements, in which every element is 'value'.
Definition IR.h:259
A function call.
Definition IR.h:490
The actual IR nodes begin here.
Definition IR.h:30
The ratio of two expressions.
Definition IR.h:83
Is the first expression equal to the second.
Definition IR.h:121
Evaluate and discard an expression, presumably because it has some side-effect.
Definition IR.h:476
Floating point constants.
Definition Expr.h:236
A for loop.
Definition IR.h:805
A pair of statements executed concurrently.
Definition IR.h:457
Free the resources associated with the given buffer.
Definition IR.h:413
Is the first expression greater than or equal to the second.
Definition IR.h:166
Is the first expression greater than the second.
Definition IR.h:157
Represents a location where storage will be hoisted to for a Func / Realize node with a given name.
Definition IR.h:932
const T * as() const
Downcast this ir node to its actual type (e.g.
Definition Expr.h:205
IRNodeType node_type() const
Definition Expr.h:212
An if-then-else block.
Definition IR.h:466
Integer constants.
Definition Expr.h:218
HALIDE_ALWAYS_INLINE bool same_as(const IntrusivePtr &other) const
Is the first expression less than or equal to the second.
Definition IR.h:148
Is the first expression less than the second.
Definition IR.h:139
A let expression, like you might find in a functional language.
Definition IR.h:271
The statement form of a let node.
Definition IR.h:282
Load a value from a named symbol if predicate is true.
Definition IR.h:217
The greater of two values.
Definition IR.h:112
The lesser of two values.
Definition IR.h:103
The remainder of a / b.
Definition IR.h:94
The result of modulus_remainder analysis.
static ModulusRemainder intersect(const ModulusRemainder &a, const ModulusRemainder &b)
The product of two expressions.
Definition IR.h:74
Is the first expression not equal to the second.
Definition IR.h:130
Logical not - true if the expression is false.
Definition IR.h:193
Logical or - is at least one of the expressions true.
Definition IR.h:184
Represent a multi-dimensional region of a Func or an ImageParam that needs to be prefetched.
Definition IR.h:910
This node is a helpful annotation to do with permissions.
Definition IR.h:315
This defines the value of a function at a multi-dimensional location.
Definition IR.h:354
A linear ramp vector node.
Definition IR.h:247
Allocate a multi-dimensional buffer of the given type and size.
Definition IR.h:427
Reinterpret value as another type, without affecting any of the bits (on little-endian systems).
Definition IR.h:47
A ternary operator.
Definition IR.h:204
Construct a new vector by taking elements from another sequence of vectors.
Definition IR.h:841
void intersect(const ExprInfo &other)
ScopedFact(ScopedFact &&that)=default
void learn_false(const Expr &fact)
std::vector< const Variable * > bounds_pop_list
ScopedFact(const ScopedFact &that)=delete
std::vector< const Variable * > pop_list
void learn_lower_bound(const Variable *v, int64_t val)
void learn_upper_bound(const Variable *v, int64_t val)
A reference-counted handle to a statement node.
Definition Expr.h:419
Store a 'value' to the buffer called 'name' at a given 'index' if 'predicate' is true.
Definition IR.h:333
String constants.
Definition Expr.h:245
The difference of two expressions.
Definition IR.h:65
Unsigned integer constants.
Definition Expr.h:227
A named variable.
Definition IR.h:758
Horizontally reduce a vector to a scalar or narrower vector using the given commutative and associati...
Definition IR.h:966
Types in the halide type system.
Definition Type.h:276
HALIDE_ALWAYS_INLINE bool is_int() const
Is this type a signed integer type?
Definition Type.h:428
HALIDE_ALWAYS_INLINE int bits() const
Return the bit size of a single element of this type.
Definition Type.h:342
HALIDE_ALWAYS_INLINE bool is_scalar() const
Is this type a scalar type? (lanes() == 1).
Definition Type.h:410
HALIDE_ALWAYS_INLINE bool is_float() const
Is this type a floating point type (float or double).
Definition Type.h:416